Repository: davisking/dlib Branch: master Commit: 0828f313d422 Files: 1964 Total size: 24.2 MB Directory structure: gitextract_o0iva_dz/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug.yml │ │ ├── config.yml │ │ └── feature.yml │ └── workflows/ │ ├── build_cpp.yml │ ├── build_matlab.yml │ └── build_python.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── dlib/ │ ├── CMakeLists.txt │ ├── LICENSE.txt │ ├── algs.h │ ├── all/ │ │ └── source.cpp │ ├── any/ │ │ ├── any.h │ │ ├── any_abstract.h │ │ ├── any_decision_function.h │ │ ├── any_decision_function_abstract.h │ │ ├── any_function.h │ │ ├── any_function_abstract.h │ │ ├── any_trainer.h │ │ ├── any_trainer_abstract.h │ │ └── storage.h │ ├── any.h │ ├── array/ │ │ ├── array_kernel.h │ │ ├── array_kernel_abstract.h │ │ ├── array_tools.h │ │ └── array_tools_abstract.h │ ├── array.h │ ├── array2d/ │ │ ├── array2d_generic_image.h │ │ ├── array2d_kernel.h │ │ ├── array2d_kernel_abstract.h │ │ └── serialize_pixel_overloads.h │ ├── array2d.h │ ├── assert.h │ ├── base64/ │ │ ├── base64_kernel_1.cpp │ │ ├── base64_kernel_1.h │ │ └── base64_kernel_abstract.h │ ├── base64.h │ ├── bayes_utils/ │ │ ├── bayes_utils.h │ │ └── bayes_utils_abstract.h │ ├── bayes_utils.h │ ├── bigint/ │ │ ├── bigint_kernel_1.cpp │ │ ├── bigint_kernel_1.h │ │ ├── bigint_kernel_2.cpp │ │ ├── bigint_kernel_2.h │ │ ├── bigint_kernel_abstract.h │ │ └── bigint_kernel_c.h │ ├── bigint.h │ ├── binary_search_tree/ │ │ ├── binary_search_tree_kernel_1.h │ │ ├── binary_search_tree_kernel_2.h │ │ ├── binary_search_tree_kernel_abstract.h │ │ └── binary_search_tree_kernel_c.h │ ├── binary_search_tree.h │ ├── bit_stream/ │ │ ├── bit_stream_kernel_1.cpp │ │ ├── bit_stream_kernel_1.h │ │ ├── bit_stream_kernel_abstract.h │ │ ├── bit_stream_kernel_c.h │ │ ├── bit_stream_multi_1.h │ │ ├── bit_stream_multi_abstract.h │ │ └── bit_stream_multi_c.h │ ├── bit_stream.h │ ├── bits/ │ │ └── c++config.h │ ├── bound_function_pointer/ │ │ ├── 
bound_function_pointer_kernel_1.h │ │ └── bound_function_pointer_kernel_abstract.h │ ├── bound_function_pointer.h │ ├── bridge/ │ │ ├── bridge.h │ │ └── bridge_abstract.h │ ├── bridge.h │ ├── bsp/ │ │ ├── bsp.cpp │ │ ├── bsp.h │ │ └── bsp_abstract.h │ ├── bsp.h │ ├── byte_orderer/ │ │ ├── byte_orderer_kernel_1.h │ │ └── byte_orderer_kernel_abstract.h │ ├── byte_orderer.h │ ├── cassert │ ├── clustering/ │ │ ├── bottom_up_cluster.h │ │ ├── bottom_up_cluster_abstract.h │ │ ├── chinese_whispers.h │ │ ├── chinese_whispers_abstract.h │ │ ├── modularity_clustering.h │ │ ├── modularity_clustering_abstract.h │ │ ├── spectral_cluster.h │ │ └── spectral_cluster_abstract.h │ ├── clustering.h │ ├── cmake │ ├── cmake_utils/ │ │ ├── FindCUDNN.cmake │ │ ├── check_if_avx_instructions_executable_on_host.cmake │ │ ├── check_if_neon_available.cmake │ │ ├── check_if_sse4_instructions_executable_on_host.cmake │ │ ├── dlib.pc.in │ │ ├── dlibConfig.cmake.in │ │ ├── find_blas.cmake │ │ ├── find_ffmpeg.cmake │ │ ├── find_libjpeg.cmake │ │ ├── find_libjxl.cmake │ │ ├── find_libpng.cmake │ │ ├── find_libwebp.cmake │ │ ├── release_build_by_default │ │ ├── set_compiler_specific_options.cmake │ │ ├── tell_visual_studio_to_use_static_runtime.cmake │ │ ├── test_for_avx/ │ │ │ ├── CMakeLists.txt │ │ │ ├── avx_test.cpp │ │ │ └── this_file_doesnt_compile.cpp │ │ ├── test_for_libjpeg/ │ │ │ ├── CMakeLists.txt │ │ │ └── libjpeg_test.cpp │ │ ├── test_for_libjxl/ │ │ │ ├── CMakeLists.txt │ │ │ └── libjxl_test.cpp │ │ ├── test_for_libpng/ │ │ │ ├── CMakeLists.txt │ │ │ └── libpng_test.cpp │ │ ├── test_for_libwebp/ │ │ │ ├── CMakeLists.txt │ │ │ └── libwebp_test.cpp │ │ ├── test_for_neon/ │ │ │ ├── CMakeLists.txt │ │ │ └── neon_test.cpp │ │ └── test_for_sse4/ │ │ ├── CMakeLists.txt │ │ ├── sse4_test.cpp │ │ └── this_file_doesnt_compile.cpp │ ├── cmd_line_parser/ │ │ ├── cmd_line_parser_check_1.h │ │ ├── cmd_line_parser_check_c.h │ │ ├── cmd_line_parser_kernel_1.h │ │ ├── cmd_line_parser_kernel_abstract.h │ 
│ ├── cmd_line_parser_kernel_c.h │ │ ├── cmd_line_parser_print_1.h │ │ ├── get_option.h │ │ └── get_option_abstract.h │ ├── cmd_line_parser.h │ ├── compress_stream/ │ │ ├── compress_stream_kernel_1.h │ │ ├── compress_stream_kernel_2.h │ │ ├── compress_stream_kernel_3.h │ │ └── compress_stream_kernel_abstract.h │ ├── compress_stream.h │ ├── conditioning_class/ │ │ ├── conditioning_class_kernel_1.h │ │ ├── conditioning_class_kernel_2.h │ │ ├── conditioning_class_kernel_3.h │ │ ├── conditioning_class_kernel_4.h │ │ ├── conditioning_class_kernel_abstract.h │ │ └── conditioning_class_kernel_c.h │ ├── conditioning_class.h │ ├── config.h │ ├── config.h.in │ ├── config_reader/ │ │ ├── config_reader_kernel_1.h │ │ ├── config_reader_kernel_abstract.h │ │ ├── config_reader_thread_safe_1.h │ │ └── config_reader_thread_safe_abstract.h │ ├── config_reader.h │ ├── console_progress_indicator.h │ ├── constexpr_if.h │ ├── control/ │ │ ├── approximate_linear_models.h │ │ ├── approximate_linear_models_abstract.h │ │ ├── lspi.h │ │ ├── lspi_abstract.h │ │ ├── mpc.h │ │ └── mpc_abstract.h │ ├── control.h │ ├── cpp_pretty_printer/ │ │ ├── cpp_pretty_printer_kernel_1.h │ │ ├── cpp_pretty_printer_kernel_2.h │ │ └── cpp_pretty_printer_kernel_abstract.h │ ├── cpp_pretty_printer.h │ ├── cpp_tokenizer/ │ │ ├── cpp_tokenizer_kernel_1.h │ │ ├── cpp_tokenizer_kernel_abstract.h │ │ └── cpp_tokenizer_kernel_c.h │ ├── cpp_tokenizer.h │ ├── crc32/ │ │ ├── crc32_kernel_1.h │ │ └── crc32_kernel_abstract.h │ ├── crc32.h │ ├── cstring │ ├── cuda/ │ │ ├── cpu_dlib.cpp │ │ ├── cpu_dlib.h │ │ ├── cublas_dlibapi.cpp │ │ ├── cublas_dlibapi.h │ │ ├── cuda_data_ptr.cpp │ │ ├── cuda_data_ptr.h │ │ ├── cuda_dlib.cu │ │ ├── cuda_dlib.h │ │ ├── cuda_errors.h │ │ ├── cuda_utils.h │ │ ├── cudnn_dlibapi.cpp │ │ ├── cudnn_dlibapi.h │ │ ├── curand_dlibapi.cpp │ │ ├── curand_dlibapi.h │ │ ├── cusolver_dlibapi.cu │ │ ├── cusolver_dlibapi.h │ │ ├── gpu_data.cpp │ │ ├── gpu_data.h │ │ ├── gpu_data_abstract.h │ │ ├── 
operation_mode.h │ │ ├── tensor.h │ │ ├── tensor_abstract.h │ │ ├── tensor_tools.cpp │ │ └── tensor_tools.h │ ├── data_io/ │ │ ├── arc_agi.h │ │ ├── arc_agi_abstract.h │ │ ├── cifar.cpp │ │ ├── cifar.h │ │ ├── cifar_abstract.h │ │ ├── image_dataset_metadata.cpp │ │ ├── image_dataset_metadata.h │ │ ├── libsvm_io.h │ │ ├── libsvm_io_abstract.h │ │ ├── load_image_dataset.h │ │ ├── load_image_dataset_abstract.h │ │ ├── mnist.cpp │ │ ├── mnist.h │ │ └── mnist_abstract.h │ ├── data_io.h │ ├── dir_nav/ │ │ ├── dir_nav_extensions.cpp │ │ ├── dir_nav_extensions.h │ │ ├── dir_nav_extensions_abstract.h │ │ ├── dir_nav_kernel_1.cpp │ │ ├── dir_nav_kernel_1.h │ │ ├── dir_nav_kernel_2.cpp │ │ ├── dir_nav_kernel_2.h │ │ ├── dir_nav_kernel_abstract.h │ │ ├── posix.h │ │ └── windows.h │ ├── dir_nav.h │ ├── directed_graph/ │ │ ├── directed_graph_kernel_1.h │ │ └── directed_graph_kernel_abstract.h │ ├── directed_graph.h │ ├── disjoint_subsets/ │ │ ├── disjoint_subsets.h │ │ ├── disjoint_subsets_abstract.h │ │ ├── disjoint_subsets_sized.h │ │ └── disjoint_subsets_sized_abstract.h │ ├── disjoint_subsets.h │ ├── dlib_basic_cpp_build_tutorial.txt │ ├── dlib_include_path_tutorial.txt │ ├── dnn/ │ │ ├── core.h │ │ ├── core_abstract.h │ │ ├── input.h │ │ ├── input_abstract.h │ │ ├── layers.h │ │ ├── layers_abstract.h │ │ ├── loss.h │ │ ├── loss_abstract.h │ │ ├── solvers.h │ │ ├── solvers_abstract.h │ │ ├── trainer.h │ │ ├── trainer_abstract.h │ │ ├── utilities.h │ │ ├── utilities_abstract.h │ │ ├── validation.h │ │ ├── visitors.h │ │ └── visitors_abstract.h │ ├── dnn.h │ ├── enable_if.h │ ├── entropy_decoder/ │ │ ├── entropy_decoder_kernel_1.cpp │ │ ├── entropy_decoder_kernel_1.h │ │ ├── entropy_decoder_kernel_2.cpp │ │ ├── entropy_decoder_kernel_2.h │ │ ├── entropy_decoder_kernel_abstract.h │ │ └── entropy_decoder_kernel_c.h │ ├── entropy_decoder.h │ ├── entropy_decoder_model/ │ │ ├── entropy_decoder_model_kernel_1.h │ │ ├── entropy_decoder_model_kernel_2.h │ │ ├── 
entropy_decoder_model_kernel_3.h │ │ ├── entropy_decoder_model_kernel_4.h │ │ ├── entropy_decoder_model_kernel_5.h │ │ ├── entropy_decoder_model_kernel_6.h │ │ └── entropy_decoder_model_kernel_abstract.h │ ├── entropy_decoder_model.h │ ├── entropy_encoder/ │ │ ├── entropy_encoder_kernel_1.cpp │ │ ├── entropy_encoder_kernel_1.h │ │ ├── entropy_encoder_kernel_2.cpp │ │ ├── entropy_encoder_kernel_2.h │ │ ├── entropy_encoder_kernel_abstract.h │ │ └── entropy_encoder_kernel_c.h │ ├── entropy_encoder.h │ ├── entropy_encoder_model/ │ │ ├── entropy_encoder_model_kernel_1.h │ │ ├── entropy_encoder_model_kernel_2.h │ │ ├── entropy_encoder_model_kernel_3.h │ │ ├── entropy_encoder_model_kernel_4.h │ │ ├── entropy_encoder_model_kernel_5.h │ │ ├── entropy_encoder_model_kernel_6.h │ │ ├── entropy_encoder_model_kernel_abstract.h │ │ └── entropy_encoder_model_kernel_c.h │ ├── entropy_encoder_model.h │ ├── error.h │ ├── external/ │ │ ├── cblas/ │ │ │ ├── CMakeLists.txt │ │ │ ├── README │ │ │ ├── cblas.h │ │ │ ├── cblas_caxpy.c │ │ │ ├── cblas_ccopy.c │ │ │ ├── cblas_cdotc_sub.c │ │ │ ├── cblas_cdotu_sub.c │ │ │ ├── cblas_cgbmv.c │ │ │ ├── cblas_cgemm.c │ │ │ ├── cblas_cgemv.c │ │ │ ├── cblas_cgerc.c │ │ │ ├── cblas_cgeru.c │ │ │ ├── cblas_chbmv.c │ │ │ ├── cblas_chemm.c │ │ │ ├── cblas_chemv.c │ │ │ ├── cblas_cher.c │ │ │ ├── cblas_cher2.c │ │ │ ├── cblas_cher2k.c │ │ │ ├── cblas_cherk.c │ │ │ ├── cblas_chpmv.c │ │ │ ├── cblas_chpr.c │ │ │ ├── cblas_chpr2.c │ │ │ ├── cblas_cscal.c │ │ │ ├── cblas_csscal.c │ │ │ ├── cblas_cswap.c │ │ │ ├── cblas_csymm.c │ │ │ ├── cblas_csyr2k.c │ │ │ ├── cblas_csyrk.c │ │ │ ├── cblas_ctbmv.c │ │ │ ├── cblas_ctbsv.c │ │ │ ├── cblas_ctpmv.c │ │ │ ├── cblas_ctpsv.c │ │ │ ├── cblas_ctrmm.c │ │ │ ├── cblas_ctrmv.c │ │ │ ├── cblas_ctrsm.c │ │ │ ├── cblas_ctrsv.c │ │ │ ├── cblas_dasum.c │ │ │ ├── cblas_daxpy.c │ │ │ ├── cblas_dcopy.c │ │ │ ├── cblas_ddot.c │ │ │ ├── cblas_dgbmv.c │ │ │ ├── cblas_dgemm.c │ │ │ ├── cblas_dgemv.c │ │ │ ├── cblas_dger.c │ │ │ 
├── cblas_dnrm2.c │ │ │ ├── cblas_drot.c │ │ │ ├── cblas_drotg.c │ │ │ ├── cblas_drotm.c │ │ │ ├── cblas_drotmg.c │ │ │ ├── cblas_dsbmv.c │ │ │ ├── cblas_dscal.c │ │ │ ├── cblas_dsdot.c │ │ │ ├── cblas_dspmv.c │ │ │ ├── cblas_dspr.c │ │ │ ├── cblas_dspr2.c │ │ │ ├── cblas_dswap.c │ │ │ ├── cblas_dsymm.c │ │ │ ├── cblas_dsymv.c │ │ │ ├── cblas_dsyr.c │ │ │ ├── cblas_dsyr2.c │ │ │ ├── cblas_dsyr2k.c │ │ │ ├── cblas_dsyrk.c │ │ │ ├── cblas_dtbmv.c │ │ │ ├── cblas_dtbsv.c │ │ │ ├── cblas_dtpmv.c │ │ │ ├── cblas_dtpsv.c │ │ │ ├── cblas_dtrmm.c │ │ │ ├── cblas_dtrmv.c │ │ │ ├── cblas_dtrsm.c │ │ │ ├── cblas_dtrsv.c │ │ │ ├── cblas_dzasum.c │ │ │ ├── cblas_dznrm2.c │ │ │ ├── cblas_f77.h │ │ │ ├── cblas_icamax.c │ │ │ ├── cblas_idamax.c │ │ │ ├── cblas_isamax.c │ │ │ ├── cblas_izamax.c │ │ │ ├── cblas_sasum.c │ │ │ ├── cblas_saxpy.c │ │ │ ├── cblas_scasum.c │ │ │ ├── cblas_scnrm2.c │ │ │ ├── cblas_scopy.c │ │ │ ├── cblas_sdot.c │ │ │ ├── cblas_sdsdot.c │ │ │ ├── cblas_sgbmv.c │ │ │ ├── cblas_sgemm.c │ │ │ ├── cblas_sgemv.c │ │ │ ├── cblas_sger.c │ │ │ ├── cblas_snrm2.c │ │ │ ├── cblas_srot.c │ │ │ ├── cblas_srotg.c │ │ │ ├── cblas_srotm.c │ │ │ ├── cblas_srotmg.c │ │ │ ├── cblas_ssbmv.c │ │ │ ├── cblas_sscal.c │ │ │ ├── cblas_sspmv.c │ │ │ ├── cblas_sspr.c │ │ │ ├── cblas_sspr2.c │ │ │ ├── cblas_sswap.c │ │ │ ├── cblas_ssymm.c │ │ │ ├── cblas_ssymv.c │ │ │ ├── cblas_ssyr.c │ │ │ ├── cblas_ssyr2.c │ │ │ ├── cblas_ssyr2k.c │ │ │ ├── cblas_ssyrk.c │ │ │ ├── cblas_stbmv.c │ │ │ ├── cblas_stbsv.c │ │ │ ├── cblas_stpmv.c │ │ │ ├── cblas_stpsv.c │ │ │ ├── cblas_strmm.c │ │ │ ├── cblas_strmv.c │ │ │ ├── cblas_strsm.c │ │ │ ├── cblas_strsv.c │ │ │ ├── cblas_xerbla.c │ │ │ ├── cblas_zaxpy.c │ │ │ ├── cblas_zcopy.c │ │ │ ├── cblas_zdotc_sub.c │ │ │ ├── cblas_zdotu_sub.c │ │ │ ├── cblas_zdscal.c │ │ │ ├── cblas_zgbmv.c │ │ │ ├── cblas_zgemm.c │ │ │ ├── cblas_zgemv.c │ │ │ ├── cblas_zgerc.c │ │ │ ├── cblas_zgeru.c │ │ │ ├── cblas_zhbmv.c │ │ │ ├── cblas_zhemm.c │ │ │ ├── cblas_zhemv.c 
│ │ │ ├── cblas_zher.c │ │ │ ├── cblas_zher2.c │ │ │ ├── cblas_zher2k.c │ │ │ ├── cblas_zherk.c │ │ │ ├── cblas_zhpmv.c │ │ │ ├── cblas_zhpr.c │ │ │ ├── cblas_zhpr2.c │ │ │ ├── cblas_zscal.c │ │ │ ├── cblas_zswap.c │ │ │ ├── cblas_zsymm.c │ │ │ ├── cblas_zsyr2k.c │ │ │ ├── cblas_zsyrk.c │ │ │ ├── cblas_ztbmv.c │ │ │ ├── cblas_ztbsv.c │ │ │ ├── cblas_ztpmv.c │ │ │ ├── cblas_ztpsv.c │ │ │ ├── cblas_ztrmm.c │ │ │ ├── cblas_ztrmv.c │ │ │ ├── cblas_ztrsm.c │ │ │ ├── cblas_ztrsv.c │ │ │ ├── cdotcsub.f │ │ │ ├── cdotusub.f │ │ │ ├── dasumsub.f │ │ │ ├── ddotsub.f │ │ │ ├── dnrm2sub.f │ │ │ ├── dsdotsub.f │ │ │ ├── dzasumsub.f │ │ │ ├── dznrm2sub.f │ │ │ ├── icamaxsub.f │ │ │ ├── idamaxsub.f │ │ │ ├── isamaxsub.f │ │ │ ├── izamaxsub.f │ │ │ ├── sasumsub.f │ │ │ ├── scasumsub.f │ │ │ ├── scnrm2sub.f │ │ │ ├── sdotsub.f │ │ │ ├── sdsdotsub.f │ │ │ ├── snrm2sub.f │ │ │ ├── zdotcsub.f │ │ │ └── zdotusub.f │ │ ├── libjpeg/ │ │ │ ├── README │ │ │ ├── cderror.h │ │ │ ├── cdjpeg.h │ │ │ ├── jaricom.c │ │ │ ├── jcapimin.c │ │ │ ├── jcapistd.c │ │ │ ├── jcarith.c │ │ │ ├── jccoefct.c │ │ │ ├── jccolor.c │ │ │ ├── jcdctmgr.c │ │ │ ├── jchuff.c │ │ │ ├── jcinit.c │ │ │ ├── jcmainct.c │ │ │ ├── jcmarker.c │ │ │ ├── jcmaster.c │ │ │ ├── jcomapi.c │ │ │ ├── jconfig.h │ │ │ ├── jcparam.c │ │ │ ├── jcprepct.c │ │ │ ├── jcsample.c │ │ │ ├── jctrans.c │ │ │ ├── jdapimin.c │ │ │ ├── jdapistd.c │ │ │ ├── jdarith.c │ │ │ ├── jdatadst.c │ │ │ ├── jdatasrc.c │ │ │ ├── jdcoefct.c │ │ │ ├── jdcolor.c │ │ │ ├── jdct.h │ │ │ ├── jddctmgr.c │ │ │ ├── jdhuff.c │ │ │ ├── jdinput.c │ │ │ ├── jdmainct.c │ │ │ ├── jdmarker.c │ │ │ ├── jdmaster.c │ │ │ ├── jdmerge.c │ │ │ ├── jdpostct.c │ │ │ ├── jdsample.c │ │ │ ├── jdtrans.c │ │ │ ├── jerror.c │ │ │ ├── jerror.h │ │ │ ├── jfdctflt.c │ │ │ ├── jfdctfst.c │ │ │ ├── jfdctint.c │ │ │ ├── jidctflt.c │ │ │ ├── jidctfst.c │ │ │ ├── jidctint.c │ │ │ ├── jinclude.h │ │ │ ├── jmemansi.c │ │ │ ├── jmemmgr.c │ │ │ ├── jmemname.c │ │ │ ├── jmemnobs.c │ │ │ ├── 
jmemsys.h │ │ │ ├── jmorecfg.h │ │ │ ├── jpegint.h │ │ │ ├── jpeglib.h │ │ │ ├── jpegtran.c │ │ │ ├── jquant1.c │ │ │ ├── jquant2.c │ │ │ ├── jutils.c │ │ │ ├── jversion.h │ │ │ ├── rdbmp.c │ │ │ ├── rdcolmap.c │ │ │ ├── rdgif.c │ │ │ ├── rdjpgcom.c │ │ │ ├── rdppm.c │ │ │ ├── rdrle.c │ │ │ ├── rdswitch.c │ │ │ ├── rdtarga.c │ │ │ ├── transupp.c │ │ │ ├── transupp.h │ │ │ ├── wrbmp.c │ │ │ ├── wrgif.c │ │ │ ├── wrjpgcom.c │ │ │ ├── wrppm.c │ │ │ ├── wrrle.c │ │ │ └── wrtarga.c │ │ ├── libpng/ │ │ │ ├── LICENSE │ │ │ ├── README │ │ │ ├── arm/ │ │ │ │ ├── arm_init.c │ │ │ │ ├── filter_neon.S │ │ │ │ ├── filter_neon_intrinsics.c │ │ │ │ └── palette_neon_intrinsics.c │ │ │ ├── png.c │ │ │ ├── png.h │ │ │ ├── pngconf.h │ │ │ ├── pngdebug.h │ │ │ ├── pngerror.c │ │ │ ├── pngget.c │ │ │ ├── pnginfo.h │ │ │ ├── pnglibconf.h │ │ │ ├── pngmem.c │ │ │ ├── pngpread.c │ │ │ ├── pngpriv.h │ │ │ ├── pngread.c │ │ │ ├── pngrio.c │ │ │ ├── pngrtran.c │ │ │ ├── pngrutil.c │ │ │ ├── pngset.c │ │ │ ├── pngstruct.h │ │ │ ├── pngtrans.c │ │ │ ├── pngwio.c │ │ │ ├── pngwrite.c │ │ │ ├── pngwtran.c │ │ │ └── pngwutil.c │ │ ├── pybind11/ │ │ │ ├── CMakeLists.txt │ │ │ ├── LICENSE │ │ │ ├── README.rst │ │ │ ├── include/ │ │ │ │ └── pybind11/ │ │ │ │ ├── attr.h │ │ │ │ ├── buffer_info.h │ │ │ │ ├── cast.h │ │ │ │ ├── chrono.h │ │ │ │ ├── common.h │ │ │ │ ├── complex.h │ │ │ │ ├── detail/ │ │ │ │ │ ├── class.h │ │ │ │ │ ├── common.h │ │ │ │ │ ├── descr.h │ │ │ │ │ ├── init.h │ │ │ │ │ ├── internals.h │ │ │ │ │ ├── type_caster_base.h │ │ │ │ │ └── typeid.h │ │ │ │ ├── eigen/ │ │ │ │ │ ├── common.h │ │ │ │ │ ├── matrix.h │ │ │ │ │ └── tensor.h │ │ │ │ ├── eigen.h │ │ │ │ ├── embed.h │ │ │ │ ├── eval.h │ │ │ │ ├── functional.h │ │ │ │ ├── gil.h │ │ │ │ ├── gil_safe_call_once.h │ │ │ │ ├── iostream.h │ │ │ │ ├── numpy.h │ │ │ │ ├── operators.h │ │ │ │ ├── options.h │ │ │ │ ├── pybind11.h │ │ │ │ ├── pytypes.h │ │ │ │ ├── stl/ │ │ │ │ │ └── filesystem.h │ │ │ │ ├── stl.h │ │ │ │ ├── stl_bind.h │ │ 
│ │ ├── type_caster_pyobject_ptr.h │ │ │ │ └── typing.h │ │ │ └── tools/ │ │ │ ├── FindCatch.cmake │ │ │ ├── FindEigen3.cmake │ │ │ ├── FindPythonLibsNew.cmake │ │ │ ├── JoinPaths.cmake │ │ │ ├── check-style.sh │ │ │ ├── cmake_uninstall.cmake.in │ │ │ ├── codespell_ignore_lines_from_errors.py │ │ │ ├── libsize.py │ │ │ ├── make_changelog.py │ │ │ ├── pybind11.pc.in │ │ │ ├── pybind11Common.cmake │ │ │ ├── pybind11Config.cmake.in │ │ │ ├── pybind11NewTools.cmake │ │ │ ├── pybind11Tools.cmake │ │ │ ├── pyproject.toml │ │ │ ├── setup_global.py.in │ │ │ └── setup_main.py.in │ │ └── zlib/ │ │ ├── README │ │ ├── adler32.c │ │ ├── compress.c │ │ ├── crc32.c │ │ ├── crc32.h │ │ ├── deflate.c │ │ ├── deflate.h │ │ ├── gzclose.c │ │ ├── gzguts.h │ │ ├── gzlib.c │ │ ├── gzread.c │ │ ├── gzwrite.c │ │ ├── infback.c │ │ ├── inffast.c │ │ ├── inffast.h │ │ ├── inffixed.h │ │ ├── inflate.c │ │ ├── inflate.h │ │ ├── inftrees.c │ │ ├── inftrees.h │ │ ├── trees.c │ │ ├── trees.h │ │ ├── uncompr.c │ │ ├── zconf.h │ │ ├── zlib.h │ │ ├── zutil.c │ │ └── zutil.h │ ├── fft/ │ │ ├── fft.cpp │ │ ├── fft.h │ │ ├── fft_size.h │ │ ├── fft_stl.h │ │ ├── kiss_fft.h │ │ └── mkl_fft.h │ ├── filtering/ │ │ ├── kalman_filter.cpp │ │ ├── kalman_filter.h │ │ ├── kalman_filter_abstract.h │ │ ├── rls_filter.h │ │ └── rls_filter_abstract.h │ ├── filtering.h │ ├── float_details.h │ ├── fstream │ ├── functional.h │ ├── general_hash/ │ │ ├── count_bits.h │ │ ├── count_bits_abstract.h │ │ ├── general_hash.h │ │ ├── hash.h │ │ ├── hash_abstract.h │ │ ├── murmur_hash3.h │ │ ├── murmur_hash3_abstract.h │ │ ├── random_hashing.h │ │ └── random_hashing_abstract.h │ ├── geometry/ │ │ ├── border_enumerator.h │ │ ├── border_enumerator_abstract.h │ │ ├── drectangle.h │ │ ├── drectangle_abstract.h │ │ ├── line.h │ │ ├── line_abstract.h │ │ ├── point_transforms.h │ │ ├── point_transforms_abstract.h │ │ ├── polygon.h │ │ ├── polygon_abstract.h │ │ ├── rectangle.h │ │ ├── rectangle_abstract.h │ │ ├── vector.h │ │ └── 
vector_abstract.h │ ├── geometry.h │ ├── global_optimization/ │ │ ├── find_max_global.h │ │ ├── find_max_global_abstract.h │ │ ├── global_function_search.cpp │ │ ├── global_function_search.h │ │ ├── global_function_search_abstract.h │ │ ├── upper_bound_function.h │ │ └── upper_bound_function_abstract.h │ ├── global_optimization.h │ ├── graph/ │ │ ├── graph_kernel_1.h │ │ └── graph_kernel_abstract.h │ ├── graph.h │ ├── graph_cuts/ │ │ ├── find_max_factor_graph_potts.h │ │ ├── find_max_factor_graph_potts_abstract.h │ │ ├── general_flow_graph.h │ │ ├── general_potts_problem.h │ │ ├── graph_labeler.h │ │ ├── graph_labeler_abstract.h │ │ ├── min_cut.h │ │ └── min_cut_abstract.h │ ├── graph_cuts.h │ ├── graph_utils/ │ │ ├── edge_list_graphs.h │ │ ├── edge_list_graphs_abstract.h │ │ ├── find_k_nearest_neighbors_lsh.h │ │ ├── find_k_nearest_neighbors_lsh_abstract.h │ │ ├── function_objects.h │ │ ├── function_objects_abstract.h │ │ ├── graph_utils.h │ │ ├── graph_utils_abstract.h │ │ ├── ordered_sample_pair.h │ │ ├── ordered_sample_pair_abstract.h │ │ ├── sample_pair.h │ │ └── sample_pair_abstract.h │ ├── graph_utils.h │ ├── graph_utils_threaded.h │ ├── gui_core/ │ │ ├── gui_core_kernel_1.cpp │ │ ├── gui_core_kernel_1.h │ │ ├── gui_core_kernel_2.cpp │ │ ├── gui_core_kernel_2.h │ │ ├── gui_core_kernel_abstract.h │ │ ├── windows.h │ │ └── xlib.h │ ├── gui_core.h │ ├── gui_widgets/ │ │ ├── base_widgets.cpp │ │ ├── base_widgets.h │ │ ├── base_widgets_abstract.h │ │ ├── canvas_drawing.cpp │ │ ├── canvas_drawing.h │ │ ├── canvas_drawing_abstract.h │ │ ├── drawable.cpp │ │ ├── drawable.h │ │ ├── drawable_abstract.h │ │ ├── fonts.cpp │ │ ├── fonts.h │ │ ├── fonts_abstract.h │ │ ├── nativefont.h │ │ ├── style.cpp │ │ ├── style.h │ │ ├── style_abstract.h │ │ ├── widgets.cpp │ │ ├── widgets.h │ │ └── widgets_abstract.h │ ├── gui_widgets.h │ ├── hash.h │ ├── hash_map/ │ │ ├── hash_map_kernel_1.h │ │ ├── hash_map_kernel_abstract.h │ │ └── hash_map_kernel_c.h │ ├── hash_map.h │ ├── 
hash_set/ │ │ ├── hash_set_kernel_1.h │ │ ├── hash_set_kernel_abstract.h │ │ └── hash_set_kernel_c.h │ ├── hash_set.h │ ├── hash_table/ │ │ ├── hash_table_kernel_1.h │ │ ├── hash_table_kernel_2.h │ │ ├── hash_table_kernel_abstract.h │ │ └── hash_table_kernel_c.h │ ├── hash_table.h │ ├── http_client/ │ │ ├── http_client.cpp │ │ ├── http_client.h │ │ └── http_client_abstract.h │ ├── image_io.h │ ├── image_keypoint/ │ │ ├── binned_vector_feature_image.h │ │ ├── binned_vector_feature_image_abstract.h │ │ ├── build_separable_poly_filters.h │ │ ├── draw_surf_points.h │ │ ├── draw_surf_points_abstract.h │ │ ├── fine_hog_image.h │ │ ├── fine_hog_image_abstract.h │ │ ├── hashed_feature_image.h │ │ ├── hashed_feature_image_abstract.h │ │ ├── hessian_pyramid.h │ │ ├── hessian_pyramid_abstract.h │ │ ├── hog.h │ │ ├── hog_abstract.h │ │ ├── nearest_neighbor_feature_image.h │ │ ├── nearest_neighbor_feature_image_abstract.h │ │ ├── poly_image.h │ │ ├── poly_image_abstract.h │ │ ├── surf.h │ │ └── surf_abstract.h │ ├── image_keypoint.h │ ├── image_loader/ │ │ ├── image_loader.h │ │ ├── image_loader_abstract.h │ │ ├── jpeg_loader.cpp │ │ ├── jpeg_loader.h │ │ ├── jpeg_loader_abstract.h │ │ ├── jxl_loader.cpp │ │ ├── jxl_loader.h │ │ ├── jxl_loader_abstract.h │ │ ├── load_image.h │ │ ├── load_image_abstract.h │ │ ├── png_loader.cpp │ │ ├── png_loader.h │ │ ├── png_loader_abstract.h │ │ ├── webp_loader.cpp │ │ ├── webp_loader.h │ │ └── webp_loader_abstract.h │ ├── image_processing/ │ │ ├── box_overlap_testing.h │ │ ├── box_overlap_testing_abstract.h │ │ ├── correlation_tracker.h │ │ ├── correlation_tracker_abstract.h │ │ ├── detection_template_tools.h │ │ ├── detection_template_tools_abstract.h │ │ ├── frontal_face_detector.h │ │ ├── frontal_face_detector_abstract.h │ │ ├── full_object_detection.h │ │ ├── full_object_detection_abstract.h │ │ ├── generic_image.h │ │ ├── object_detector.h │ │ ├── object_detector_abstract.h │ │ ├── remove_unobtainable_rectangles.h │ │ ├── 
remove_unobtainable_rectangles_abstract.h │ │ ├── render_face_detections.h │ │ ├── render_face_detections_abstract.h │ │ ├── scan_fhog_pyramid.h │ │ ├── scan_fhog_pyramid_abstract.h │ │ ├── scan_image.h │ │ ├── scan_image_abstract.h │ │ ├── scan_image_boxes.h │ │ ├── scan_image_boxes_abstract.h │ │ ├── scan_image_custom.h │ │ ├── scan_image_custom_abstract.h │ │ ├── scan_image_pyramid.h │ │ ├── scan_image_pyramid_abstract.h │ │ ├── scan_image_pyramid_tools.h │ │ ├── scan_image_pyramid_tools_abstract.h │ │ ├── setup_hashed_features.h │ │ ├── setup_hashed_features_abstract.h │ │ ├── shape_predictor.h │ │ ├── shape_predictor_abstract.h │ │ ├── shape_predictor_trainer.h │ │ └── shape_predictor_trainer_abstract.h │ ├── image_processing.h │ ├── image_saver/ │ │ ├── dng_shared.h │ │ ├── image_saver.h │ │ ├── image_saver_abstract.h │ │ ├── save_jpeg.cpp │ │ ├── save_jpeg.h │ │ ├── save_jpeg_abstract.h │ │ ├── save_jxl.cpp │ │ ├── save_jxl.h │ │ ├── save_jxl_abstract.h │ │ ├── save_png.cpp │ │ ├── save_png.h │ │ ├── save_png_abstract.h │ │ ├── save_webp.cpp │ │ ├── save_webp.h │ │ └── save_webp_abstract.h │ ├── image_transforms/ │ │ ├── assign_image.h │ │ ├── assign_image_abstract.h │ │ ├── colormaps.h │ │ ├── colormaps_abstract.h │ │ ├── draw.h │ │ ├── draw_abstract.h │ │ ├── edge_detector.h │ │ ├── edge_detector_abstract.h │ │ ├── equalize_histogram.h │ │ ├── equalize_histogram_abstract.h │ │ ├── fhog.h │ │ ├── fhog_abstract.h │ │ ├── hough_transform.h │ │ ├── hough_transform_abstract.h │ │ ├── image_pyramid.h │ │ ├── image_pyramid_abstract.h │ │ ├── integral_image.h │ │ ├── integral_image_abstract.h │ │ ├── interpolation.h │ │ ├── interpolation_abstract.h │ │ ├── label_connected_blobs.h │ │ ├── label_connected_blobs_abstract.h │ │ ├── lbp.h │ │ ├── lbp_abstract.h │ │ ├── morphological_operations.h │ │ ├── morphological_operations_abstract.h │ │ ├── random_color_transform.h │ │ ├── random_color_transform_abstract.h │ │ ├── random_cropper.h │ │ ├── 
random_cropper_abstract.h │ │ ├── segment_image.h │ │ ├── segment_image_abstract.h │ │ ├── spatial_filtering.h │ │ ├── spatial_filtering_abstract.h │ │ ├── thresholding.h │ │ └── thresholding_abstract.h │ ├── image_transforms.h │ ├── interfaces/ │ │ ├── cmd_line_parser_option.h │ │ ├── enumerable.h │ │ ├── map_pair.h │ │ └── remover.h │ ├── invoke.h │ ├── iomanip │ ├── iosfwd │ ├── iosockstream/ │ │ ├── iosockstream.h │ │ └── iosockstream_abstract.h │ ├── iosockstream.h │ ├── iostream │ ├── is_kind.h │ ├── istream │ ├── java/ │ │ ├── CMakeLists.txt │ │ ├── cmake_swig_jni │ │ ├── java_array.h │ │ ├── run_test.sh │ │ ├── swig_api.h │ │ └── swig_test.java │ ├── linker/ │ │ ├── linker_kernel_1.cpp │ │ ├── linker_kernel_1.h │ │ └── linker_kernel_abstract.h │ ├── linker.h │ ├── locale │ ├── logger/ │ │ ├── extra_logger_headers.cpp │ │ ├── extra_logger_headers.h │ │ ├── logger_config_file.cpp │ │ ├── logger_config_file.h │ │ ├── logger_kernel_1.cpp │ │ ├── logger_kernel_1.h │ │ └── logger_kernel_abstract.h │ ├── logger.h │ ├── lsh/ │ │ ├── create_random_projection_hash.h │ │ ├── create_random_projection_hash_abstract.h │ │ ├── hashes.h │ │ ├── hashes_abstract.h │ │ ├── projection_hash.h │ │ └── projection_hash_abstract.h │ ├── lsh.h │ ├── lz77_buffer/ │ │ ├── lz77_buffer_kernel_1.h │ │ ├── lz77_buffer_kernel_2.h │ │ ├── lz77_buffer_kernel_abstract.h │ │ └── lz77_buffer_kernel_c.h │ ├── lz77_buffer.h │ ├── lzp_buffer/ │ │ ├── lzp_buffer_kernel_1.h │ │ ├── lzp_buffer_kernel_2.h │ │ ├── lzp_buffer_kernel_abstract.h │ │ └── lzp_buffer_kernel_c.h │ ├── lzp_buffer.h │ ├── manifold_regularization/ │ │ ├── linear_manifold_regularizer.h │ │ └── linear_manifold_regularizer_abstract.h │ ├── manifold_regularization.h │ ├── map/ │ │ ├── map_kernel_1.h │ │ ├── map_kernel_abstract.h │ │ └── map_kernel_c.h │ ├── map.h │ ├── math/ │ │ ├── bessel.h │ │ ├── details/ │ │ │ └── bessel.h │ │ └── windows.h │ ├── math.h │ ├── matlab/ │ │ ├── CMakeLists.txt │ │ ├── README.txt │ │ ├── 
call_matlab.h │ │ ├── cmake_mex_wrapper │ │ ├── example.m │ │ ├── example_mex_callback.cpp │ │ ├── example_mex_class.cpp │ │ ├── example_mex_function.cpp │ │ ├── example_mex_struct.cpp │ │ ├── mex_wrapper.cpp │ │ ├── sig_traits.h │ │ ├── subprocess_stream.cpp │ │ └── subprocess_stream.h │ ├── matrix/ │ │ ├── cblas_constants.h │ │ ├── lapack/ │ │ │ ├── fortran_id.h │ │ │ ├── gees.h │ │ │ ├── geev.h │ │ │ ├── geqrf.h │ │ │ ├── gesdd.h │ │ │ ├── gesvd.h │ │ │ ├── getrf.h │ │ │ ├── ormqr.h │ │ │ ├── pbtrf.h │ │ │ ├── potrf.h │ │ │ ├── syev.h │ │ │ └── syevr.h │ │ ├── matrix.h │ │ ├── matrix_abstract.h │ │ ├── matrix_assign.h │ │ ├── matrix_assign_fwd.h │ │ ├── matrix_blas_bindings.h │ │ ├── matrix_cholesky.h │ │ ├── matrix_conj_trans.h │ │ ├── matrix_conv.h │ │ ├── matrix_conv_abstract.h │ │ ├── matrix_data_layout.h │ │ ├── matrix_data_layout_abstract.h │ │ ├── matrix_default_mul.h │ │ ├── matrix_eigenvalue.h │ │ ├── matrix_exp.h │ │ ├── matrix_exp_abstract.h │ │ ├── matrix_expressions.h │ │ ├── matrix_fft.h │ │ ├── matrix_fft_abstract.h │ │ ├── matrix_fwd.h │ │ ├── matrix_generic_image.h │ │ ├── matrix_la.h │ │ ├── matrix_la_abstract.h │ │ ├── matrix_lu.h │ │ ├── matrix_mat.h │ │ ├── matrix_mat_abstract.h │ │ ├── matrix_math_functions.h │ │ ├── matrix_math_functions_abstract.h │ │ ├── matrix_op.h │ │ ├── matrix_qr.h │ │ ├── matrix_read_from_istream.h │ │ ├── matrix_subexp.h │ │ ├── matrix_subexp_abstract.h │ │ ├── matrix_trsm.h │ │ ├── matrix_utilities.h │ │ ├── matrix_utilities_abstract.h │ │ ├── symmetric_matrix_cache.h │ │ └── symmetric_matrix_cache_abstract.h │ ├── matrix.h │ ├── md5/ │ │ ├── md5_kernel_1.cpp │ │ ├── md5_kernel_1.h │ │ └── md5_kernel_abstract.h │ ├── md5.h │ ├── media/ │ │ ├── ffmpeg_demuxer.h │ │ ├── ffmpeg_details.h │ │ ├── ffmpeg_muxer.h │ │ ├── ffmpeg_utils.h │ │ └── sink.h │ ├── media.h │ ├── member_function_pointer/ │ │ ├── make_mfp.h │ │ ├── make_mfp_abstract.h │ │ ├── member_function_pointer_kernel_1.h │ │ └── 
member_function_pointer_kernel_abstract.h │ ├── member_function_pointer.h │ ├── memory_manager/ │ │ ├── memory_manager_kernel_1.h │ │ ├── memory_manager_kernel_2.h │ │ ├── memory_manager_kernel_3.h │ │ └── memory_manager_kernel_abstract.h │ ├── memory_manager.h │ ├── memory_manager_global/ │ │ ├── memory_manager_global_kernel_1.h │ │ └── memory_manager_global_kernel_abstract.h │ ├── memory_manager_global.h │ ├── memory_manager_stateless/ │ │ ├── memory_manager_stateless_kernel_1.h │ │ ├── memory_manager_stateless_kernel_2.h │ │ └── memory_manager_stateless_kernel_abstract.h │ ├── memory_manager_stateless.h │ ├── metaprogramming.h │ ├── misc_api/ │ │ ├── misc_api_kernel_1.cpp │ │ ├── misc_api_kernel_1.h │ │ ├── misc_api_kernel_2.cpp │ │ ├── misc_api_kernel_2.h │ │ ├── misc_api_kernel_abstract.h │ │ ├── misc_api_shared.h │ │ ├── posix.h │ │ └── windows.h │ ├── misc_api.h │ ├── mlp/ │ │ ├── mlp_kernel_1.h │ │ ├── mlp_kernel_abstract.h │ │ └── mlp_kernel_c.h │ ├── mlp.h │ ├── noncopyable.h │ ├── numeric_constants.h │ ├── numerical_integration/ │ │ ├── integrate_function_adapt_simpson.h │ │ └── integrate_function_adapt_simpson_abstract.h │ ├── numerical_integration.h │ ├── opencv/ │ │ ├── cv_image.h │ │ ├── cv_image_abstract.h │ │ ├── to_open_cv.h │ │ └── to_open_cv_abstract.h │ ├── opencv.h │ ├── optimization/ │ │ ├── elastic_net.h │ │ ├── elastic_net_abstract.h │ │ ├── find_max_factor_graph_nmplp.h │ │ ├── find_max_factor_graph_nmplp_abstract.h │ │ ├── find_max_factor_graph_viterbi.h │ │ ├── find_max_factor_graph_viterbi_abstract.h │ │ ├── find_max_parse_cky.h │ │ ├── find_max_parse_cky_abstract.h │ │ ├── find_optimal_parameters.h │ │ ├── find_optimal_parameters_abstract.h │ │ ├── isotonic_regression.h │ │ ├── isotonic_regression_abstract.h │ │ ├── max_cost_assignment.h │ │ ├── max_cost_assignment_abstract.h │ │ ├── max_sum_submatrix.h │ │ ├── max_sum_submatrix_abstract.h │ │ ├── optimization.h │ │ ├── optimization_abstract.h │ │ ├── optimization_bobyqa.h │ │ ├── 
optimization_bobyqa_abstract.h │ │ ├── optimization_least_squares.h │ │ ├── optimization_least_squares_abstract.h │ │ ├── optimization_line_search.h │ │ ├── optimization_line_search_abstract.h │ │ ├── optimization_oca.h │ │ ├── optimization_oca_abstract.h │ │ ├── optimization_search_strategies.h │ │ ├── optimization_search_strategies_abstract.h │ │ ├── optimization_solve_qp2_using_smo.h │ │ ├── optimization_solve_qp2_using_smo_abstract.h │ │ ├── optimization_solve_qp3_using_smo.h │ │ ├── optimization_solve_qp3_using_smo_abstract.h │ │ ├── optimization_solve_qp_using_smo.h │ │ ├── optimization_solve_qp_using_smo_abstract.h │ │ ├── optimization_stop_strategies.h │ │ ├── optimization_stop_strategies_abstract.h │ │ ├── optimization_trust_region.h │ │ └── optimization_trust_region_abstract.h │ ├── optimization.h │ ├── optional.h │ ├── ostream │ ├── overloaded.h │ ├── pipe/ │ │ ├── pipe_kernel_1.h │ │ └── pipe_kernel_abstract.h │ ├── pipe.h │ ├── pixel.h │ ├── platform.h │ ├── python/ │ │ ├── numpy_image.h │ │ ├── pyassert.h │ │ ├── pybind_utils.h │ │ └── serialize_pickle.h │ ├── python.h │ ├── quantum_computing/ │ │ ├── quantum_computing.h │ │ └── quantum_computing_abstract.h │ ├── quantum_computing.h │ ├── queue/ │ │ ├── queue_kernel_1.h │ │ ├── queue_kernel_2.h │ │ ├── queue_kernel_abstract.h │ │ ├── queue_kernel_c.h │ │ ├── queue_sort_1.h │ │ └── queue_sort_abstract.h │ ├── queue.h │ ├── rand/ │ │ ├── mersenne_twister.h │ │ ├── rand_kernel_1.h │ │ └── rand_kernel_abstract.h │ ├── rand.h │ ├── random_forest/ │ │ ├── random_forest_regression.h │ │ └── random_forest_regression_abstract.h │ ├── random_forest.h │ ├── ref.h │ ├── reference_counter/ │ │ ├── reference_counter_kernel_1.h │ │ └── reference_counter_kernel_abstract.h │ ├── reference_counter.h │ ├── revision.h.in │ ├── scope.h │ ├── sequence/ │ │ ├── sequence_compare_1.h │ │ ├── sequence_compare_abstract.h │ │ ├── sequence_kernel_1.h │ │ ├── sequence_kernel_2.h │ │ ├── sequence_kernel_abstract.h │ │ ├── 
sequence_kernel_c.h │ │ ├── sequence_sort_1.h │ │ ├── sequence_sort_2.h │ │ └── sequence_sort_abstract.h │ ├── sequence.h │ ├── serialize.h │ ├── server/ │ │ ├── server_http.cpp │ │ ├── server_http.h │ │ ├── server_http_abstract.h │ │ ├── server_iostream.cpp │ │ ├── server_iostream.h │ │ ├── server_iostream_abstract.h │ │ ├── server_kernel.cpp │ │ ├── server_kernel.h │ │ └── server_kernel_abstract.h │ ├── server.h │ ├── set/ │ │ ├── set_compare_1.h │ │ ├── set_compare_abstract.h │ │ ├── set_kernel_1.h │ │ ├── set_kernel_abstract.h │ │ └── set_kernel_c.h │ ├── set.h │ ├── set_utils/ │ │ ├── set_utils.h │ │ └── set_utils_abstract.h │ ├── set_utils.h │ ├── simd/ │ │ ├── simd4f.h │ │ ├── simd4i.h │ │ ├── simd8f.h │ │ ├── simd8i.h │ │ └── simd_check.h │ ├── simd.h │ ├── sliding_buffer/ │ │ ├── circular_buffer.h │ │ ├── circular_buffer_abstract.h │ │ ├── sliding_buffer_kernel_1.h │ │ ├── sliding_buffer_kernel_abstract.h │ │ └── sliding_buffer_kernel_c.h │ ├── sliding_buffer.h │ ├── smart_pointers/ │ │ ├── scoped_ptr.h │ │ ├── shared_ptr.h │ │ ├── shared_ptr_abstract.h │ │ ├── shared_ptr_thread_safe.h │ │ ├── shared_ptr_thread_safe_abstract.h │ │ ├── weak_ptr.h │ │ └── weak_ptr_abstract.h │ ├── smart_pointers.h │ ├── smart_pointers_thread_safe.h │ ├── sockets/ │ │ ├── posix.h │ │ ├── sockets_extensions.cpp │ │ ├── sockets_extensions.h │ │ ├── sockets_extensions_abstract.h │ │ ├── sockets_kernel_1.cpp │ │ ├── sockets_kernel_1.h │ │ ├── sockets_kernel_2.cpp │ │ ├── sockets_kernel_2.h │ │ ├── sockets_kernel_abstract.h │ │ └── windows.h │ ├── sockets.h │ ├── sockstreambuf/ │ │ ├── sockstreambuf.cpp │ │ ├── sockstreambuf.h │ │ ├── sockstreambuf_abstract.h │ │ ├── sockstreambuf_unbuffered.cpp │ │ └── sockstreambuf_unbuffered.h │ ├── sockstreambuf.h │ ├── sort.h │ ├── sparse_vector.h │ ├── sqlite/ │ │ ├── sqlite.h │ │ ├── sqlite_abstract.h │ │ ├── sqlite_tools.h │ │ └── sqlite_tools_abstract.h │ ├── sqlite.h │ ├── sstream │ ├── stack/ │ │ ├── stack_kernel_1.h │ │ ├── 
stack_kernel_abstract.h │ │ └── stack_kernel_c.h │ ├── stack.h │ ├── stack_trace.cpp │ ├── stack_trace.h │ ├── static_map/ │ │ ├── static_map_kernel_1.h │ │ ├── static_map_kernel_abstract.h │ │ └── static_map_kernel_c.h │ ├── static_map.h │ ├── static_set/ │ │ ├── static_set_compare_1.h │ │ ├── static_set_compare_abstract.h │ │ ├── static_set_kernel_1.h │ │ ├── static_set_kernel_abstract.h │ │ └── static_set_kernel_c.h │ ├── static_set.h │ ├── statistics/ │ │ ├── average_precision.h │ │ ├── average_precision_abstract.h │ │ ├── cca.h │ │ ├── cca_abstract.h │ │ ├── dpca.h │ │ ├── dpca_abstract.h │ │ ├── image_feature_sampling.h │ │ ├── image_feature_sampling_abstract.h │ │ ├── lda.h │ │ ├── lda_abstract.h │ │ ├── random_subset_selector.h │ │ ├── random_subset_selector_abstract.h │ │ ├── running_gradient.h │ │ ├── running_gradient_abstract.h │ │ ├── sammon.h │ │ ├── sammon_abstract.h │ │ ├── statistics.h │ │ ├── statistics_abstract.h │ │ ├── vector_normalizer_frobmetric.h │ │ └── vector_normalizer_frobmetric_abstract.h │ ├── statistics.h │ ├── std_allocator.h │ ├── stl_checked/ │ │ ├── std_vector_c.h │ │ └── std_vector_c_abstract.h │ ├── stl_checked.h │ ├── string/ │ │ ├── cassert │ │ ├── iomanip │ │ ├── iosfwd │ │ ├── iostream │ │ ├── locale │ │ ├── string.h │ │ └── string_abstract.h │ ├── string.h │ ├── svm/ │ │ ├── active_learning.h │ │ ├── active_learning_abstract.h │ │ ├── assignment_function.h │ │ ├── assignment_function_abstract.h │ │ ├── auto.cpp │ │ ├── auto.h │ │ ├── auto_abstract.h │ │ ├── cross_validate_assignment_trainer.h │ │ ├── cross_validate_assignment_trainer_abstract.h │ │ ├── cross_validate_graph_labeling_trainer.h │ │ ├── cross_validate_graph_labeling_trainer_abstract.h │ │ ├── cross_validate_multiclass_trainer.h │ │ ├── cross_validate_multiclass_trainer_abstract.h │ │ ├── cross_validate_object_detection_trainer.h │ │ ├── cross_validate_object_detection_trainer_abstract.h │ │ ├── cross_validate_regression_trainer.h │ │ ├── 
cross_validate_regression_trainer_abstract.h │ │ ├── cross_validate_sequence_labeler.h │ │ ├── cross_validate_sequence_labeler_abstract.h │ │ ├── cross_validate_sequence_segmenter.h │ │ ├── cross_validate_sequence_segmenter_abstract.h │ │ ├── cross_validate_track_association_trainer.h │ │ ├── cross_validate_track_association_trainer_abstract.h │ │ ├── empirical_kernel_map.h │ │ ├── empirical_kernel_map_abstract.h │ │ ├── feature_ranking.h │ │ ├── feature_ranking_abstract.h │ │ ├── function.h │ │ ├── function_abstract.h │ │ ├── kcentroid.h │ │ ├── kcentroid_abstract.h │ │ ├── kcentroid_overloads.h │ │ ├── kernel.h │ │ ├── kernel_abstract.h │ │ ├── kernel_matrix.h │ │ ├── kernel_matrix_abstract.h │ │ ├── kkmeans.h │ │ ├── kkmeans_abstract.h │ │ ├── krls.h │ │ ├── krls_abstract.h │ │ ├── krr_trainer.h │ │ ├── krr_trainer_abstract.h │ │ ├── linearly_independent_subset_finder.h │ │ ├── linearly_independent_subset_finder_abstract.h │ │ ├── multiclass_tools.h │ │ ├── multiclass_tools_abstract.h │ │ ├── null_df.h │ │ ├── null_trainer.h │ │ ├── null_trainer_abstract.h │ │ ├── num_nonnegative_weights.h │ │ ├── one_vs_all_decision_function.h │ │ ├── one_vs_all_decision_function_abstract.h │ │ ├── one_vs_all_trainer.h │ │ ├── one_vs_all_trainer_abstract.h │ │ ├── one_vs_one_decision_function.h │ │ ├── one_vs_one_decision_function_abstract.h │ │ ├── one_vs_one_trainer.h │ │ ├── one_vs_one_trainer_abstract.h │ │ ├── pegasos.h │ │ ├── pegasos_abstract.h │ │ ├── ranking_tools.h │ │ ├── ranking_tools_abstract.h │ │ ├── rbf_network.h │ │ ├── rbf_network_abstract.h │ │ ├── reduced.h │ │ ├── reduced_abstract.h │ │ ├── rls.h │ │ ├── rls_abstract.h │ │ ├── roc_trainer.h │ │ ├── roc_trainer_abstract.h │ │ ├── rr_trainer.h │ │ ├── rr_trainer_abstract.h │ │ ├── rvm.h │ │ ├── rvm_abstract.h │ │ ├── sequence_labeler.h │ │ ├── sequence_labeler_abstract.h │ │ ├── sequence_segmenter.h │ │ ├── sequence_segmenter_abstract.h │ │ ├── simplify_linear_decision_function.h │ │ ├── 
simplify_linear_decision_function_abstract.h │ │ ├── sort_basis_vectors.h │ │ ├── sort_basis_vectors_abstract.h │ │ ├── sparse_kernel.h │ │ ├── sparse_kernel_abstract.h │ │ ├── sparse_vector.h │ │ ├── sparse_vector_abstract.h │ │ ├── structural_assignment_trainer.h │ │ ├── structural_assignment_trainer_abstract.h │ │ ├── structural_graph_labeling_trainer.h │ │ ├── structural_graph_labeling_trainer_abstract.h │ │ ├── structural_object_detection_trainer.h │ │ ├── structural_object_detection_trainer_abstract.h │ │ ├── structural_sequence_labeling_trainer.h │ │ ├── structural_sequence_labeling_trainer_abstract.h │ │ ├── structural_sequence_segmentation_trainer.h │ │ ├── structural_sequence_segmentation_trainer_abstract.h │ │ ├── structural_svm_assignment_problem.h │ │ ├── structural_svm_assignment_problem_abstract.h │ │ ├── structural_svm_distributed.h │ │ ├── structural_svm_distributed_abstract.h │ │ ├── structural_svm_graph_labeling_problem.h │ │ ├── structural_svm_graph_labeling_problem_abstract.h │ │ ├── structural_svm_object_detection_problem.h │ │ ├── structural_svm_object_detection_problem_abstract.h │ │ ├── structural_svm_problem.h │ │ ├── structural_svm_problem_abstract.h │ │ ├── structural_svm_problem_threaded.h │ │ ├── structural_svm_problem_threaded_abstract.h │ │ ├── structural_svm_sequence_labeling_problem.h │ │ ├── structural_svm_sequence_labeling_problem_abstract.h │ │ ├── structural_track_association_trainer.h │ │ ├── structural_track_association_trainer_abstract.h │ │ ├── svm.h │ │ ├── svm_abstract.h │ │ ├── svm_c_ekm_trainer.h │ │ ├── svm_c_ekm_trainer_abstract.h │ │ ├── svm_c_linear_dcd_trainer.h │ │ ├── svm_c_linear_dcd_trainer_abstract.h │ │ ├── svm_c_linear_trainer.h │ │ ├── svm_c_linear_trainer_abstract.h │ │ ├── svm_c_trainer.h │ │ ├── svm_c_trainer_abstract.h │ │ ├── svm_multiclass_linear_trainer.h │ │ ├── svm_multiclass_linear_trainer_abstract.h │ │ ├── svm_nu_trainer.h │ │ ├── svm_nu_trainer_abstract.h │ │ ├── svm_one_class_trainer.h │ │ ├── 
svm_one_class_trainer_abstract.h │ │ ├── svm_rank_trainer.h │ │ ├── svm_rank_trainer_abstract.h │ │ ├── svm_threaded.h │ │ ├── svm_threaded_abstract.h │ │ ├── svr_linear_trainer.h │ │ ├── svr_linear_trainer_abstract.h │ │ ├── svr_trainer.h │ │ ├── svr_trainer_abstract.h │ │ ├── track_association_function.h │ │ └── track_association_function_abstract.h │ ├── svm.h │ ├── svm_threaded.h │ ├── sync_extension/ │ │ ├── sync_extension_kernel_1.h │ │ └── sync_extension_kernel_abstract.h │ ├── sync_extension.h │ ├── test/ │ │ ├── CMakeLists.txt │ │ ├── WINDOWS_build_and_run_all_unit_tests.bat │ │ ├── active_learning.cpp │ │ ├── any.cpp │ │ ├── any_function.cpp │ │ ├── array.cpp │ │ ├── array2d.cpp │ │ ├── assignment_learning.cpp │ │ ├── base64.cpp │ │ ├── bayes_nets.cpp │ │ ├── bigint.cpp │ │ ├── binary_search_tree.h │ │ ├── binary_search_tree_kernel_1a.cpp │ │ ├── binary_search_tree_kernel_2a.cpp │ │ ├── binary_search_tree_mm1.cpp │ │ ├── binary_search_tree_mm2.cpp │ │ ├── blas_bindings/ │ │ │ ├── CMakeLists.txt │ │ │ ├── blas_bindings_dot.cpp │ │ │ ├── blas_bindings_gemm.cpp │ │ │ ├── blas_bindings_gemv.cpp │ │ │ ├── blas_bindings_ger.cpp │ │ │ ├── blas_bindings_scal_axpy.cpp │ │ │ └── vector.cpp │ │ ├── bridge.cpp │ │ ├── bsp.cpp │ │ ├── byte_orderer.cpp │ │ ├── cca.cpp │ │ ├── checkerboard.h │ │ ├── clustering.cpp │ │ ├── cmd_line_parser.cpp │ │ ├── cmd_line_parser.h │ │ ├── cmd_line_parser_wchar_t.cpp │ │ ├── compress_stream.cpp │ │ ├── conditioning_class.cpp │ │ ├── conditioning_class.h │ │ ├── conditioning_class_c.cpp │ │ ├── config_reader.cpp │ │ ├── constexpr_if.cpp │ │ ├── correlation_tracker.cpp │ │ ├── crc32.cpp │ │ ├── create_iris_datafile.cpp │ │ ├── create_iris_datafile.h │ │ ├── cublas.cpp │ │ ├── data_io.cpp │ │ ├── directed_graph.cpp │ │ ├── discriminant_pca.cpp │ │ ├── disjoint_subsets.cpp │ │ ├── disjoint_subsets_sized.cpp │ │ ├── dnn.cpp │ │ ├── ekm_and_lisf.cpp │ │ ├── elastic_net.cpp │ │ ├── empirical_kernel_map.cpp │ │ ├── entropy_coder.cpp │ │ ├── 
entropy_encoder_model.cpp │ │ ├── example.cpp │ │ ├── example_args.cpp │ │ ├── examples/ │ │ │ └── CMakeLists.txt │ │ ├── face.cpp │ │ ├── ffmpeg.cpp │ │ ├── ffmpeg_data/ │ │ │ ├── 116-288045-0000.flac │ │ │ ├── 116-288045-0001.m4a │ │ │ ├── LICENSE.TXT │ │ │ ├── MOT17-13-SDP-raw.h265 │ │ │ ├── MOT20-08-raw_shorter.h264 │ │ │ └── details.cfg │ │ ├── fft.cpp │ │ ├── fftr_good_data.h │ │ ├── fhog.cpp │ │ ├── filtering.cpp │ │ ├── find_max_factor_graph_nmplp.cpp │ │ ├── find_max_factor_graph_viterbi.cpp │ │ ├── find_optimal_parameters.cpp │ │ ├── geometry.cpp │ │ ├── global_optimization.cpp │ │ ├── graph.cpp │ │ ├── graph_cuts.cpp │ │ ├── graph_labeler.cpp │ │ ├── gui/ │ │ │ ├── CMakeLists.txt │ │ │ └── main.cpp │ │ ├── hash.cpp │ │ ├── hash_map.cpp │ │ ├── hash_set.cpp │ │ ├── hash_table.cpp │ │ ├── hog_image.cpp │ │ ├── image.cpp │ │ ├── invoke.cpp │ │ ├── iosockstream.cpp │ │ ├── is_same_object.cpp │ │ ├── isotonic_regression.cpp │ │ ├── kcentroid.cpp │ │ ├── kernel_matrix.cpp │ │ ├── kmeans.cpp │ │ ├── learning_to_track.cpp │ │ ├── least_squares.cpp │ │ ├── linear_manifold_regularizer.cpp │ │ ├── lspi.cpp │ │ ├── lz77_buffer.cpp │ │ ├── main.cpp │ │ ├── makefile │ │ ├── map.cpp │ │ ├── math.cpp │ │ ├── matrix.cpp │ │ ├── matrix2.cpp │ │ ├── matrix3.cpp │ │ ├── matrix4.cpp │ │ ├── matrix_chol.cpp │ │ ├── matrix_eig.cpp │ │ ├── matrix_lu.cpp │ │ ├── matrix_qr.cpp │ │ ├── max_cost_assignment.cpp │ │ ├── max_sum_submatrix.cpp │ │ ├── md5.cpp │ │ ├── member_function_pointer.cpp │ │ ├── metaprogramming.cpp │ │ ├── mpc.cpp │ │ ├── multithreaded_object.cpp │ │ ├── numerical_integration.cpp │ │ ├── object_detector.cpp │ │ ├── oca.cpp │ │ ├── one_vs_all_trainer.cpp │ │ ├── one_vs_one_trainer.cpp │ │ ├── opt_qp_solver.cpp │ │ ├── optimization.cpp │ │ ├── optimization_test_functions.cpp │ │ ├── optimization_test_functions.h │ │ ├── optional.cpp │ │ ├── parallel_for.cpp │ │ ├── parse.cpp │ │ ├── pipe.cpp │ │ ├── pixel.cpp │ │ ├── probabilistic.cpp │ │ ├── pyramid_down.cpp │ │ 
├── queue.cpp │ │ ├── rand.cpp │ │ ├── random_forest.cpp │ │ ├── ranking.cpp │ │ ├── read_write_mutex.cpp │ │ ├── reference_counter.cpp │ │ ├── rls.cpp │ │ ├── sammon.cpp │ │ ├── scan_image.cpp │ │ ├── scope.cpp │ │ ├── sequence.cpp │ │ ├── sequence_labeler.cpp │ │ ├── sequence_segmenter.cpp │ │ ├── serialize.cpp │ │ ├── set.cpp │ │ ├── sldf.cpp │ │ ├── sliding_buffer.cpp │ │ ├── smart_pointers.cpp │ │ ├── sockets.cpp │ │ ├── sockets2.cpp │ │ ├── sockstreambuf.cpp │ │ ├── sparse_vector.cpp │ │ ├── stack.cpp │ │ ├── static_map.cpp │ │ ├── static_set.cpp │ │ ├── statistics.cpp │ │ ├── std_vector_c.cpp │ │ ├── stft_good_data.h │ │ ├── string.cpp │ │ ├── svm.cpp │ │ ├── svm_c_linear.cpp │ │ ├── svm_c_linear_dcd.cpp │ │ ├── svm_multiclass_linear.cpp │ │ ├── svm_struct.cpp │ │ ├── svr_linear_trainer.cpp │ │ ├── symmetric_matrix_cache.cpp │ │ ├── te.cpp │ │ ├── tester.cpp │ │ ├── tester.h │ │ ├── thread_pool.cpp │ │ ├── threads.cpp │ │ ├── timer.cpp │ │ ├── tokenizer.cpp │ │ ├── tools/ │ │ │ └── CMakeLists.txt │ │ ├── trust_region.cpp │ │ ├── tuple.cpp │ │ ├── type_safe_union.cpp │ │ └── vectorstream.cpp │ ├── test_for_odr_violations.cpp │ ├── test_for_odr_violations.h │ ├── threads/ │ │ ├── async.cpp │ │ ├── async.h │ │ ├── async_abstract.h │ │ ├── auto_mutex_extension.h │ │ ├── auto_mutex_extension_abstract.h │ │ ├── auto_unlock_extension.h │ │ ├── auto_unlock_extension_abstract.h │ │ ├── create_new_thread_extension.h │ │ ├── create_new_thread_extension_abstract.h │ │ ├── multithreaded_object_extension.cpp │ │ ├── multithreaded_object_extension.h │ │ ├── multithreaded_object_extension_abstract.h │ │ ├── parallel_for_extension.h │ │ ├── parallel_for_extension_abstract.h │ │ ├── posix.h │ │ ├── read_write_mutex_extension.h │ │ ├── read_write_mutex_extension_abstract.h │ │ ├── rmutex_extension.h │ │ ├── rmutex_extension_abstract.h │ │ ├── rsignaler_extension.h │ │ ├── rsignaler_extension_abstract.h │ │ ├── thread_function_extension.h │ │ ├── 
thread_function_extension_abstract.h │ │ ├── thread_pool_extension.cpp │ │ ├── thread_pool_extension.h │ │ ├── thread_pool_extension_abstract.h │ │ ├── thread_specific_data_extension.h │ │ ├── thread_specific_data_extension_abstract.h │ │ ├── threaded_object_extension.cpp │ │ ├── threaded_object_extension.h │ │ ├── threaded_object_extension_abstract.h │ │ ├── threads_kernel.h │ │ ├── threads_kernel_1.cpp │ │ ├── threads_kernel_1.h │ │ ├── threads_kernel_2.cpp │ │ ├── threads_kernel_2.h │ │ ├── threads_kernel_abstract.h │ │ ├── threads_kernel_shared.cpp │ │ ├── threads_kernel_shared.h │ │ └── windows.h │ ├── threads.h │ ├── time_this.h │ ├── timeout/ │ │ ├── timeout.h │ │ └── timeout_abstract.h │ ├── timeout.h │ ├── timer/ │ │ ├── timer.cpp │ │ ├── timer.h │ │ ├── timer_abstract.h │ │ └── timer_heavy.h │ ├── timer.h │ ├── timing.h │ ├── tokenizer/ │ │ ├── bpe_tokenizer.h │ │ ├── bpe_tokenizer_abstract.h │ │ ├── tokenizer_kernel_1.cpp │ │ ├── tokenizer_kernel_1.h │ │ ├── tokenizer_kernel_abstract.h │ │ └── tokenizer_kernel_c.h │ ├── tokenizer.h │ ├── tuple/ │ │ ├── tuple.h │ │ └── tuple_abstract.h │ ├── tuple.h │ ├── type_safe_union/ │ │ ├── type_safe_union_kernel.h │ │ └── type_safe_union_kernel_abstract.h │ ├── type_safe_union.h │ ├── type_traits.h │ ├── uintn.h │ ├── unicode/ │ │ ├── unicode.cpp │ │ ├── unicode.h │ │ └── unicode_abstract.h │ ├── unicode.h │ ├── unordered_pair.h │ ├── utility.h │ ├── vectorstream/ │ │ ├── unserialize.h │ │ ├── unserialize_abstract.h │ │ ├── vectorstream.h │ │ └── vectorstream_abstract.h │ ├── vectorstream.h │ ├── windows_magic.h │ ├── xml_parser/ │ │ ├── xml_parser_kernel_1.h │ │ ├── xml_parser_kernel_abstract.h │ │ └── xml_parser_kernel_interfaces.h │ └── xml_parser.h ├── docs/ │ ├── .logger_revnum │ ├── README.txt │ ├── bash_helper_functions │ ├── docs/ │ │ ├── algorithms.xml │ │ ├── api.xml │ │ ├── bayes.xml │ │ ├── books.xml │ │ ├── change_log.xml │ │ ├── compile.xml │ │ ├── compression.xml │ │ ├── containers.xml │ │ ├── 
dlib.css │ │ ├── dlib.js │ │ ├── enable_if.html │ │ ├── faq.xml │ │ ├── find_max_global_example.webm │ │ ├── graph_tools.xml │ │ ├── howto_contribute.xml │ │ ├── imaging.xml │ │ ├── index.xml │ │ ├── intro.xml │ │ ├── kernel_1a.txt │ │ ├── kernel_1a.xml │ │ ├── kernel_1b.txt │ │ ├── kernel_1b.xml │ │ ├── kernel_1c.txt │ │ ├── kernel_1c.xml │ │ ├── kernel_1da.txt │ │ ├── kernel_1da.xml │ │ ├── kernel_1db.txt │ │ ├── kernel_1db.xml │ │ ├── kernel_1ea.txt │ │ ├── kernel_1ea.xml │ │ ├── kernel_1eb.txt │ │ ├── kernel_1eb.xml │ │ ├── kernel_1ec.txt │ │ ├── kernel_1ec.xml │ │ ├── kernel_2a.txt │ │ ├── kernel_2a.xml │ │ ├── kernel_3a.txt │ │ ├── kernel_3a.xml │ │ ├── kernel_3b.txt │ │ ├── kernel_3b.xml │ │ ├── license.xml │ │ ├── linear_algebra.xml │ │ ├── main_menu.xml │ │ ├── metaprogramming.xml │ │ ├── ml.xml │ │ ├── ml_guide.dia │ │ ├── network.xml │ │ ├── old_release_notes.xml │ │ ├── optimization.xml │ │ ├── other.xml │ │ ├── parsing.xml │ │ ├── python/ │ │ │ ├── conf.py │ │ │ ├── generate_dlib_listing.py │ │ │ └── index.rst │ │ ├── release_notes.xml │ │ ├── stylesheet.xsl │ │ ├── term_index.xml │ │ └── watershed.webm │ ├── makedocs │ ├── makerel │ ├── testenv │ └── testenv_rel ├── examples/ │ ├── 3d_point_cloud_ex.cpp │ ├── CMakeLists.txt │ ├── LICENSE_FOR_EXAMPLE_PROGRAMS.txt │ ├── assignment_learning_ex.cpp │ ├── bayes_net_ex.cpp │ ├── bayes_net_from_disk_ex.cpp │ ├── bayes_net_gui_ex.cpp │ ├── bridge_ex.cpp │ ├── bsp_ex.cpp │ ├── compress_stream_ex.cpp │ ├── config.txt │ ├── config_reader_ex.cpp │ ├── custom_trainer_ex.cpp │ ├── dir_nav_ex.cpp │ ├── dnn_dcgan_train_ex.cpp │ ├── dnn_face_recognition_ex.cpp │ ├── dnn_imagenet_ex.cpp │ ├── dnn_imagenet_train_ex.cpp │ ├── dnn_inception_ex.cpp │ ├── dnn_instance_segmentation_ex.cpp │ ├── dnn_instance_segmentation_ex.h │ ├── dnn_instance_segmentation_train_ex.cpp │ ├── dnn_introduction2_ex.cpp │ ├── dnn_introduction3_ex.cpp │ ├── dnn_introduction_ex.cpp │ ├── dnn_metric_learning_ex.cpp │ ├── 
dnn_metric_learning_on_images_ex.cpp │ ├── dnn_mmod_dog_hipsterizer.cpp │ ├── dnn_mmod_ex.cpp │ ├── dnn_mmod_face_detection_ex.cpp │ ├── dnn_mmod_find_cars2_ex.cpp │ ├── dnn_mmod_find_cars_ex.cpp │ ├── dnn_mmod_train_find_cars_ex.cpp │ ├── dnn_self_supervised_learning_ex.cpp │ ├── dnn_semantic_segmentation_ex.cpp │ ├── dnn_semantic_segmentation_ex.h │ ├── dnn_semantic_segmentation_train_ex.cpp │ ├── dnn_yolo_train_ex.cpp │ ├── empirical_kernel_map_ex.cpp │ ├── face_detection_ex.cpp │ ├── face_landmark_detection_ex.cpp │ ├── faces/ │ │ ├── image_metadata_stylesheet.xsl │ │ ├── testing.xml │ │ ├── testing_with_face_landmarks.xml │ │ ├── training.xml │ │ └── training_with_face_landmarks.xml │ ├── ffmpeg_file_to_speaker_ex.cpp │ ├── ffmpeg_info_ex.cpp │ ├── ffmpeg_microphone_to_file_ex.cpp │ ├── ffmpeg_rtsp_ex.cpp │ ├── ffmpeg_screen_grab_ex.cpp │ ├── ffmpeg_video_decoding2_ex.cpp │ ├── ffmpeg_video_decoding_ex.cpp │ ├── ffmpeg_video_demuxing2_ex.cpp │ ├── ffmpeg_video_demuxing_ex.cpp │ ├── ffmpeg_video_encoding_ex.cpp │ ├── ffmpeg_video_muxing_ex.cpp │ ├── ffmpeg_webcam_face_pose_ex.cpp │ ├── fhog_ex.cpp │ ├── fhog_object_detector_ex.cpp │ ├── file_to_code_ex.cpp │ ├── graph_labeling_ex.cpp │ ├── gui_api_ex.cpp │ ├── hough_transform_ex.cpp │ ├── image_ex.cpp │ ├── integrate_function_adapt_simp_ex.cpp │ ├── iosockstream_ex.cpp │ ├── kcentroid_ex.cpp │ ├── kkmeans_ex.cpp │ ├── krls_ex.cpp │ ├── krls_filter_ex.cpp │ ├── krr_classification_ex.cpp │ ├── krr_regression_ex.cpp │ ├── learning_to_track_ex.cpp │ ├── least_squares_ex.cpp │ ├── linear_manifold_regularizer_ex.cpp │ ├── logger_custom_output_ex.cpp │ ├── logger_ex.cpp │ ├── logger_ex_2.cpp │ ├── matrix_ex.cpp │ ├── matrix_expressions_ex.cpp │ ├── max_cost_assignment_ex.cpp │ ├── member_function_pointer_ex.cpp │ ├── mlp_ex.cpp │ ├── model_selection_ex.cpp │ ├── mpc_ex.cpp │ ├── multiclass_classification_ex.cpp │ ├── multithreaded_object_ex.cpp │ ├── object_detector_advanced_ex.cpp │ ├── object_detector_ex.cpp │ ├── 
one_class_classifiers_ex.cpp │ ├── optimization_ex.cpp │ ├── parallel_for_ex.cpp │ ├── pascal_voc_2012.h │ ├── pipe_ex.cpp │ ├── pipe_ex_2.cpp │ ├── quantum_computing_ex.cpp │ ├── queue_ex.cpp │ ├── random_cropper_ex.cpp │ ├── rank_features_ex.cpp │ ├── resnet.h │ ├── running_stats_ex.cpp │ ├── rvm_ex.cpp │ ├── rvm_regression_ex.cpp │ ├── sequence_labeler_ex.cpp │ ├── sequence_segmenter_ex.cpp │ ├── server_http_ex.cpp │ ├── server_iostream_ex.cpp │ ├── slm_advanced_train_ex.cpp │ ├── slm_basic_train_ex.cpp │ ├── slm_data.h │ ├── slm_defs.h │ ├── sockets_ex.cpp │ ├── sockstreambuf_ex.cpp │ ├── sqlite_ex.cpp │ ├── std_allocator_ex.cpp │ ├── surf_ex.cpp │ ├── svm_c_ex.cpp │ ├── svm_ex.cpp │ ├── svm_pegasos_ex.cpp │ ├── svm_rank_ex.cpp │ ├── svm_sparse_ex.cpp │ ├── svm_struct_ex.cpp │ ├── svr_ex.cpp │ ├── thread_function_ex.cpp │ ├── thread_pool_ex.cpp │ ├── threaded_object_ex.cpp │ ├── threads_ex.cpp │ ├── timer_ex.cpp │ ├── train_object_detector.cpp │ ├── train_shape_predictor_ex.cpp │ ├── using_custom_kernels_ex.cpp │ ├── video_frames/ │ │ └── license.txt │ ├── video_tracking_ex.cpp │ ├── webcam_face_pose_ex.cpp │ └── xml_parser_ex.cpp ├── pyproject.toml ├── python_examples/ │ ├── LICENSE_FOR_EXAMPLE_PROGRAMS.txt │ ├── cnn_face_detector.py │ ├── correlation_tracker.py │ ├── face_alignment.py │ ├── face_clustering.py │ ├── face_detector.py │ ├── face_jitter.py │ ├── face_landmark_detection.py │ ├── face_recognition.py │ ├── find_candidate_object_locations.py │ ├── global_optimization.py │ ├── max_cost_assignment.py │ ├── opencv_webcam_face_detection.py │ ├── requirements.txt │ ├── sequence_segmenter.py │ ├── svm_binary_classifier.py │ ├── svm_rank.py │ ├── svm_struct.py │ ├── train_object_detector.py │ └── train_shape_predictor.py ├── setup.py └── tools/ ├── archive/ │ ├── CMakeLists.txt │ └── train_face_5point_model.cpp ├── convert_dlib_nets_to_caffe/ │ ├── CMakeLists.txt │ ├── main.cpp │ └── running_a_dlib_model_with_caffe_example.py ├── htmlify/ │ ├── 
CMakeLists.txt │ ├── htmlify.cpp │ ├── to_xml.cpp │ ├── to_xml.h │ └── to_xml_example/ │ ├── example.xml │ ├── output.xml │ ├── stylesheet.xsl │ └── test.cpp ├── imglab/ │ ├── CMakeLists.txt │ ├── README.txt │ ├── convert_imglab_paths_to_relative │ ├── copy_imglab_dataset │ └── src/ │ ├── cluster.cpp │ ├── cluster.h │ ├── common.cpp │ ├── common.h │ ├── convert_idl.cpp │ ├── convert_idl.h │ ├── convert_pascal_v1.cpp │ ├── convert_pascal_v1.h │ ├── convert_pascal_xml.cpp │ ├── convert_pascal_xml.h │ ├── flip_dataset.cpp │ ├── flip_dataset.h │ ├── main.cpp │ ├── metadata_editor.cpp │ └── metadata_editor.h ├── python/ │ ├── CMakeLists.txt │ ├── dlib/ │ │ └── __init__.py.in │ ├── src/ │ │ ├── basic.cpp │ │ ├── cca.cpp │ │ ├── cnn_face_detector.cpp │ │ ├── conversion.h │ │ ├── correlation_tracker.cpp │ │ ├── decision_functions.cpp │ │ ├── dlib.cpp │ │ ├── face_recognition.cpp │ │ ├── global_optimization.cpp │ │ ├── gui.cpp │ │ ├── image.cpp │ │ ├── image2.cpp │ │ ├── image3.cpp │ │ ├── image4.cpp │ │ ├── image_dataset_metadata.cpp │ │ ├── indexing.h │ │ ├── line.cpp │ │ ├── matrix.cpp │ │ ├── numpy_returns.cpp │ │ ├── object_detection.cpp │ │ ├── opaque_types.h │ │ ├── other.cpp │ │ ├── rectangles.cpp │ │ ├── sequence_segmenter.cpp │ │ ├── serialize_object_detector.h │ │ ├── shape_predictor.cpp │ │ ├── shape_predictor.h │ │ ├── simple_object_detector.h │ │ ├── simple_object_detector_py.h │ │ ├── svm_c_trainer.cpp │ │ ├── svm_rank_trainer.cpp │ │ ├── svm_struct.cpp │ │ ├── testing_results.h │ │ └── vector.cpp │ └── test/ │ ├── .gitignore │ ├── generate_numpy_returns_test_data.py │ ├── shape.pkl │ ├── test_array.py │ ├── test_chinese_whispers.py │ ├── test_global_optimization.py │ ├── test_matrix.py │ ├── test_numpy_returns.py │ ├── test_point.py │ ├── test_range.py │ ├── test_rgb_pixel.py │ ├── test_sparse_vector.py │ ├── test_svm_c_trainer.py │ ├── test_vector.py │ └── utils.py └── visual_studio_natvis/ ├── README.txt └── dlib.natvis 
================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug.yml ================================================ name: Bug Report description: Create a bug report. title: "[Bug]: " assignees: [] body: - type: markdown attributes: value: | Before you ask a question, check Google for a solution, [the dlib FAQ](http://dlib.net/faq.html), or consult the dlib documentation. Every single function in dlib is documented in detail. If you obviously haven't read the documentation your issue will be closed. - type: dropdown id: platform attributes: label: What Operating System(s) are you seeing this problem on? description: Select all that apply multiple: true options: - Linux (x86-64) - Linux (aarch64) - macOS (Intel) - macOS (Apple Silicon) - Windows - Other (please specify in the Steps to Reproduce) validations: required: true - type: input id: version attributes: label: dlib version description: "The commit hash, tag name, or package version" placeholder: "19.24" validations: required: true - type: input id: python attributes: label: Python version description: "The version of Python you are using" placeholder: "3.8" validations: required: false - type: input id: compiler attributes: label: Compiler description: "The compiler and version you used to build dlib from source" placeholder: "e.g.: GCC 9, MSVC 19, clang 10" validations: required: true - type: textarea id: expected attributes: label: Expected Behavior description: | A clear and concise description of what you expected to happen - type: textarea id: current attributes: label: Current Behavior description: A clear and concise description of what the bug is placeholder: | Tell us what happened. Explain in detail the current behavior.
validations: required: true - type: textarea id: reproduce validations: required: true attributes: label: Steps to Reproduce description: Steps to reproduce the behavior placeholder: | Please include as much information as possible that can help to reproduce and understand the issue. Including a minimal example that reproduces the bug is very useful. - type: textarea id: other attributes: label: Anything else? description: | Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: 💬 Discussions url: https://github.com/davisking/dlib/discussions about: Please start usage discussions here ================================================ FILE: .github/ISSUE_TEMPLATE/feature.yml ================================================ name: Feature Request description: Suggest an idea for dlib labels: type:enhancement assignees: [] body: - type: markdown attributes: value: | It is OK to suggest interesting improvements to dlib, even if you are not volunteering to implement them. **However, the issue tracker is not a code writing service, do not ask for someone to write code for you**. E.g. Do not ask for feature improvements to the example programs. **If there is some feature improvement you want in an example program then it's up to you to write it**. - type: textarea id: description attributes: label: Main idea description: Describe what you want to add or improve validations: required: true - type: textarea id: other attributes: label: Anything else? description: Add any other context about the suggestion here. 
================================================ FILE: .github/workflows/build_cpp.yml ================================================ name: C++ on: push: branches: - master paths: - ".github/workflows/build_cpp.yml" - "**.cpp" - "**.h" - "**.c" - "**.cu" - "**.cmake" - "**CMakeLists.txt" pull_request: branches: - master paths: - ".github/workflows/build_cpp.yml" - "**.cpp" - "**.h" - "**.c" - "**.cu" - "**.cmake" - "**CMakeLists.txt" defaults: run: shell: bash working-directory: dlib/test jobs: ubuntu-22-04-gcc-default-cmake-3-17-ffmpeg5: runs-on: 'ubuntu-22.04' steps: - uses: actions/checkout@v2 - name: Install dependencies run: | sudo apt update sudo apt install libwebp-dev make yasm - name: Cache cmake 3.17.0 uses: actions/cache@v3 id: cache-cmake-download with: # cache this folder: path: ~/cmake-3.17.0-Linux-x86_64 key: cmake-3.17.0_try3 - run: | # Get the minimum version of cmake dlib supports wget https://cmake.org/files/v3.17/cmake-3.17.0-Linux-x86_64.tar.gz tar -xf cmake-3.17.0-Linux-x86_64.tar.gz -C ~ if: steps.cache-cmake-download.outputs.cache-hit != 'true' - name: Cache FFmpeg 5 uses: actions/cache@v3 id: cache-ffmpeg5 with: path: /home/runner/ffmpeg-n5.1.3_installation key: ffmpeg-n5.1.3_try4 - name: Build FFmpeg 5 if: steps.cache-ffmpeg5.outputs.cache-hit != 'true' run: | wget https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n5.1.3.tar.gz tar -xf n5.1.3.tar.gz cd FFmpeg-n5.1.3 ./configure --prefix=/home/runner/ffmpeg-n5.1.3_installation --disable-doc --disable-programs make -j4 make install cd .. - name: Configure run: | mkdir build cd build ~/cmake-3.17.0-Linux-x86_64/bin/cmake -DCMAKE_PREFIX_PATH=/home/runner/ffmpeg-n5.1.3_installation .. 
- name: Build just tests run: | cd build make -j4 dtest - name: Test run: build/dtest --runall -q - name: Build examples, etc run: | cd build make -j2 ubuntu-latest-gcc-11-blas: runs-on: 'ubuntu-latest' steps: - uses: actions/checkout@v2 - name: Install dependencies run: | sudo apt update sudo apt install libwebp-dev libavformat-dev libavcodec-dev libavdevice-dev libavfilter-dev libswresample-dev libswscale-dev libavutil-dev sudo apt install libopenblas-dev liblapack-dev - name: Install gcc 11 run: | sudo apt install gcc-11 g++-11 sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 110 --slave /usr/bin/g++ g++ /usr/bin/g++-11 --slave /usr/bin/gcov gcov /usr/bin/gcov-11 - name: Configure run: cmake ${{ github.workspace }}/dlib/test -B build - name: Build just tests run: cmake --build build --config Release --target dtest --parallel 4 - name: Test run: build/dtest --runall -q - name: Build examples, etc run: cmake --build build --config Release --parallel 2 # Test the BLAS bindings - name: Configure BLAS binding tests run: cmake ${{ github.workspace }}/dlib/test/blas_bindings -B build_blas_bindings - name: Build blas binding tests run: cmake --build build_blas_bindings --config Debug --parallel 4 - name: Test BLAS bindings run: build_blas_bindings/dtest --runall -q ubuntu-latest-clang-default-avx: runs-on: 'ubuntu-latest' steps: - uses: actions/checkout@v2 - name: Install dependencies run: | sudo apt update sudo apt install libwebp-dev libavformat-dev libavcodec-dev libavdevice-dev libavfilter-dev libswresample-dev libswscale-dev libavutil-dev - name: Configure run: | export CC=/usr/bin/clang export CXX=/usr/bin/clang++ cmake ${{ github.workspace }}/dlib/test -B build -DUSE_AVX_INSTRUCTIONS=1 - name: Build just tests run: cmake --build build --config Release --target dtest --parallel 4 - name: Test run: build/dtest --runall -q - name: Build examples, etc run: cmake --build build --config Release --parallel 2 ubuntu-22-04-ffmpeg701: runs-on: 
'ubuntu-22.04' steps: - uses: actions/checkout@v2 - name: Install dependencies run: | sudo apt update sudo apt install make yasm - name: Cache FFmpeg 7 uses: actions/cache@v3 id: cache-ffmpeg7 with: path: /home/runner/ffmpeg-n7.0.1_installation key: ffmpeg-n7.0.1_try2 - name: Build FFmpeg 7 if: steps.cache-ffmpeg7.outputs.cache-hit != 'true' run: | wget https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n7.0.1.tar.gz tar -xf n7.0.1.tar.gz cd FFmpeg-n7.0.1 ./configure --prefix=/home/runner/ffmpeg-n7.0.1_installation --disable-doc --disable-programs make -j4 make install cd .. - name: Configure run: cmake . -B build -DCMAKE_PREFIX_PATH=/home/runner/ffmpeg-n7.0.1_installation - name: Build ffmpeg example run: cmake --build build --config Release --target ffmpeg_video_muxing_ex --parallel 4 ubuntu-22-04-ffmpeg711: runs-on: 'ubuntu-22.04' steps: - uses: actions/checkout@v2 - name: Install dependencies run: | sudo apt update sudo apt install make yasm - name: Cache FFmpeg 7 uses: actions/cache@v3 id: cache-ffmpeg7 with: path: /home/runner/ffmpeg-n7.1.1_installation key: ffmpeg-n7.1.1_try1 - name: Build FFmpeg 7 if: steps.cache-ffmpeg7.outputs.cache-hit != 'true' run: | wget https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n7.1.1.tar.gz tar -xf n7.1.1.tar.gz cd FFmpeg-n7.1.1 ./configure --prefix=/home/runner/ffmpeg-n7.1.1_installation --disable-doc --disable-programs make -j4 make install cd .. - name: Configure run: cmake . -B build -DCMAKE_PREFIX_PATH=/home/runner/ffmpeg-n7.1.1_installation - name: Build ffmpeg example run: cmake --build build --config Release --target ffmpeg_video_muxing_ex --parallel 4 windows-latest: runs-on: 'windows-latest' steps: - uses: actions/checkout@v2 - name: Configure run: | # don't use CMake 3.25.0 https://gitlab.kitware.com/cmake/cmake/-/issues/23975 pip3 install cmake==3.24.0 cmake . 
-B build - name: Build just tests run: cmake --build build --config Release --target dtest --parallel 4 - name: Test run: build/Release/dtest.exe --runall -q - name: Build ancillary tools run: cmake --build build --config Release --target imglab htmlify dtoc --parallel 4 # Disable this because macos targets aren't working on github actions right now. #macos-latest: # runs-on: 'macos-latest' # steps: # - uses: actions/checkout@v2 # - name: Configure # # MacOS machines often come with low quality BLAS libraries installed, so don't use those. # run: cmake ${{ github.workspace }}/dlib/test -B build -DDLIB_USE_BLAS=0 -DDLIB_USE_LAPACK=0 # - name: Build just tests # run: cmake --build build --config Release --target dtest --parallel 4 # - name: Test # run: build/dtest --runall --no_test_timer -q # - name: Build examples, etc # run: cmake --build build --config Release --parallel 2 ================================================ FILE: .github/workflows/build_matlab.yml ================================================ name: Matlab on: push: branches: - master paths: - ".github/workflows/build_matlab.yml" - "**.cpp" - "**.h" - "**.c" - "**.cu" - "**.cmake" - "**CMakeLists.txt" pull_request: branches: - master paths: - ".github/workflows/build_matlab.yml" - "**.cpp" - "**.h" - "**.c" - "**.cu" - "**.cmake" - "**CMakeLists.txt" defaults: run: shell: bash working-directory: dlib/matlab jobs: mex-wrapper: runs-on: 'ubuntu-22.04' steps: - uses: actions/checkout@v2 - name: Setup MATLAB uses: matlab-actions/setup-matlab@v1 - name: Compile mex wrappers run: | pwd mkdir build cd build cmake .. cmake --build . 
--config Release --target install --parallel 4 - name: Run example script uses: matlab-actions/run-command@v1 with: command: cd dlib/matlab; example ================================================ FILE: .github/workflows/build_python.yml ================================================ name: Python on: push: branches: - master paths: - ".github/workflows/build_python.yml" - "**.cpp" - "**.h" - "**.c" - "**.cu" - "**.cmake" - "**CMakeLists.txt" - "**py" pull_request: branches: - master paths: - ".github/workflows/build_python.yml" - "**.cpp" - "**.h" - "**.c" - "**.cu" - "**.cmake" - "**CMakeLists.txt" - "**py" defaults: run: shell: bash jobs: Windows: runs-on: 'windows-latest' steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 - name: Install python deps run: | pip install pytest numpy # don't use CMake 3.25.0 https://gitlab.kitware.com/cmake/cmake/-/issues/23975 pip3 install cmake==3.24.0 - name: Build run: | pip3 install cmake==3.24.0 pip3 install . - name: Test run: python -m pytest --ignore docs --ignore dlib Ubuntu: runs-on: 'ubuntu-latest' steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 - name: Install python deps run: pip install pytest numpy - name: Build run: | pip install . - name: Test run: python -m pytest --ignore docs --ignore dlib # Disabled for now since something is going sideways with python packages on github actions # MacOS: # runs-on: 'macos-latest' # steps: # - uses: actions/checkout@v3 # - uses: actions/setup-python@v4 # - name: Install python deps # run: pip3 install pytest numpy # - name: Build # run: | # pip install . 
# - name: Test # run: python3 -m pytest --ignore docs --ignore dlib ================================================ FILE: .gitignore ================================================ **/.idea *~ *.swp *.o *.so *.pyc build build2 dist *.egg-info/ docs/release/ docs/docs/web/ docs/docs/chm/ docs/docs/cache/ docs/docs/git-logs.xml docs/docs/python/classes.txt docs/docs/python/functions.txt docs/docs/python/constants.txt **/.vscode **/venv ================================================ FILE: CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.17.0) project(dlib_project) ############################################################################# # # # READ examples/CMakeLists.txt TO SEE HOW TO USE DLIB FROM C++ WITH CMAKE # # # ############################################################################# get_directory_property(has_parent PARENT_DIRECTORY) if(NOT has_parent) # When you call add_subdirectory(dlib) from a parent CMake project dlib's # CMake scripts will assume you want to statically compile dlib into # whatever you are building rather than create a standalone copy of dlib. # This means CMake will build dlib as a static library, disable dlib's # install targets so they don't clutter your project, and adjust a few other # minor things that are convenient when statically building dlib as part of # your own projects. # # On the other hand, if there is no parent CMake project or if # DLIB_IN_PROJECT_BUILD is set to false, CMake will compile dlib as a normal # standalone library (either shared or static, based on the state of CMake's # BUILD_SHARED_LIBS flag), and include the usual install targets so you can # install dlib on your computer via `make install`. Since the only reason # to build this CMakeLists.txt (the one you are reading right now) by itself # is if you want to install dlib, we indicate as such by setting # DLIB_IN_PROJECT_BUILD to false. 
set(DLIB_IN_PROJECT_BUILD false) endif() add_subdirectory(dlib) ================================================ FILE: LICENSE.txt ================================================ Boost Software License - Version 1.0 - August 17th, 2003 Permission is hereby granted, free of charge, to any person or organization obtaining a copy of the software and accompanying documentation covered by this license (the "Software") to use, reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative works of the Software, and to permit third-parties to whom the Software is furnished to do so, all subject to the following: The copyright notices in the Software and this entire statement, including the above license grant, this restriction and the following disclaimer, must be included in all copies of the Software, in whole or in part, and all derivative works of the Software, unless such copies or derivative works are solely in the form of machine-executable object code generated by a source language processor. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ # # MANIFEST.in # # Manifest template for creating the dlib source distribution. 
include MANIFEST.in include setup.py include README.md # sources recursive-include dlib ** recursive-include python_examples *.txt *.py recursive-include tools/python ** prune tools/python/build* prune dlib/cmake_utils/*/build* prune dlib/test global-exclude *.pyc ================================================ FILE: README.md ================================================ # dlib C++ library [![GitHub Actions C++ Status](https://github.com/davisking/dlib/actions/workflows/build_cpp.yml/badge.svg)](https://github.com/davisking/dlib/actions/workflows/build_cpp.yml) [![GitHub Actions Python Status](https://github.com/davisking/dlib/actions/workflows/build_python.yml/badge.svg)](https://github.com/davisking/dlib/actions/workflows/build_python.yml) Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real world problems. See [http://dlib.net](http://dlib.net) for the main project documentation and API reference. ## Compiling dlib C++ example programs Go into the examples folder and type: ```bash mkdir build; cd build; cmake .. ; cmake --build . ``` That will build all the examples. If you have a CPU that supports AVX instructions then turn them on like this: ```bash mkdir build; cd build; cmake .. -DUSE_AVX_INSTRUCTIONS=1; cmake --build . ``` Doing so will make some things run faster. Finally, Visual Studio users should usually do everything in 64bit mode. By default Visual Studio is 32bit, both in its outputs and its own execution, so you have to explicitly tell it to use 64bits. Since it's not the 1990s anymore you probably want to use 64bits. Do that with a cmake invocation like this: ```bash cmake .. -G "Visual Studio 14 2015 Win64" -T host=x64 ``` ## Compiling your own C++ programs that use dlib The examples folder has a [CMake tutorial](https://github.com/davisking/dlib/blob/master/examples/CMakeLists.txt) that tells you what to do. 
There are also additional instructions on the [dlib web site](http://dlib.net/compile.html). Alternatively, if you are using the [vcpkg](https://github.com/Microsoft/vcpkg/) dependency manager you can download and install dlib with CMake integration in a single command: ```bash vcpkg install dlib ``` ## Compiling dlib Python API Either fetch the latest stable release of dlib from PyPi and install that: ```bash pip install dlib ``` Or fetch the very latest version from github and install that: ```bash git clone https://github.com/davisking/dlib.git cd dlib pip install . ``` It's possible to change build settings by passing parameters to `setup.py` or `DLIB_*` environment variables. For example, setting the environment variable `DLIB_NO_GUI_SUPPORT` to `ON` will add the cmake option `-DDLIB_NO_GUI_SUPPORT=ON`. ## Running the unit test suite Type the following to compile and run the dlib unit test suite: ```bash cd dlib/test mkdir build cd build cmake .. cmake --build . --config Release ./dtest --runall ``` Note that on windows your compiler might put the test executable in a subfolder called `Release`. If that's the case then you have to go to that folder before running the test. This library is licensed under the Boost Software License, which can be found in [dlib/LICENSE.txt](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt). The long and short of the license is that you can use dlib however you like, even in closed source commercial software. ## dlib sponsors This research is based in part upon work supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA) under contract number 2014-14071600010. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of ODNI, IARPA, or the U.S. Government. 
================================================ FILE: dlib/CMakeLists.txt ================================================ # # This is a CMake makefile. You can find the cmake utility and # information about it at http://www.cmake.org # cmake_minimum_required(VERSION 3.17.0) set(CMAKE_DISABLE_SOURCE_CHANGES ON) set(CMAKE_DISABLE_IN_SOURCE_BUILD ON) if(POLICY CMP0077) cmake_policy(SET CMP0077 NEW) endif() project(dlib LANGUAGES C CXX) set(CPACK_PACKAGE_NAME "dlib") set(CPACK_PACKAGE_VERSION_MAJOR "20") set(CPACK_PACKAGE_VERSION_MINOR "0") set(CPACK_PACKAGE_VERSION_PATCH "99") set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH}) # Only print these messages once, even if dlib is added multiple times via add_subdirectory() if (NOT TARGET dlib) message(STATUS "Using CMake version: ${CMAKE_VERSION}") message(STATUS "Compiling dlib version: ${VERSION}") endif() set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_utils) include(cmake_utils/set_compiler_specific_options.cmake) # Adhere to GNU filesystem layout conventions include(GNUInstallDirs) if (POLICY CMP0075) cmake_policy(SET CMP0075 NEW) endif() # default to a Release build (except if CMAKE_BUILD_TYPE is set) include(cmake_utils/release_build_by_default) # Set DLIB_VERSION in the including CMake file so they can use it to do whatever they want. get_directory_property(has_parent PARENT_DIRECTORY) if(has_parent) set(DLIB_VERSION ${VERSION} PARENT_SCOPE) if (NOT DEFINED DLIB_IN_PROJECT_BUILD) set(DLIB_IN_PROJECT_BUILD true) endif() endif() if (COMMAND pybind11_add_module AND MSVC) # True when building a python extension module using Visual Studio. We care # about this because a huge number of windows users have broken systems, and # in particular, they have broken or incompatibly installed copies of things # like libjpeg or libpng. So if we detect we are in this mode we will never # ever link to those libraries. 
Instead, we link to the copy included with # dlib. set (BUILDING_PYTHON_IN_MSVC true) else() set (BUILDING_PYTHON_IN_MSVC false) endif() if (DLIB_IN_PROJECT_BUILD) # Check if we are being built as part of a pybind11 module. if (COMMAND pybind11_add_module) set(CMAKE_POSITION_INDEPENDENT_CODE True) if (CMAKE_COMPILER_IS_GNUCXX) # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set # -fPIC for GCC but sometimes it still doesn't get set, so make sure it # does. add_definitions("-fPIC") endif() # Make DLIB_ASSERT statements not abort the python interpreter, but just return an error. list(APPEND active_preprocessor_switches "-DDLIB_NO_ABORT_ON_2ND_FATAL_ERROR") endif() # DLIB_IN_PROJECT_BUILD==true means you are using dlib by invoking # add_subdirectory(dlib) in the parent project. In this case, we always want # to build dlib as a static library so the parent project doesn't need to # deal with some random dlib shared library file. It is much better to # statically compile dlib into the parent project. So the following bit of # CMake ensures that happens. However, we have to take care to compile dlib # with position independent code if appropriate (i.e. if the parent project # is a shared library). if (BUILD_SHARED_LIBS) if (CMAKE_COMPILER_IS_GNUCXX) # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set # -fPIC for GCC but sometimes it still doesn't get set, so make sure it # does. add_definitions("-fPIC") endif() set(CMAKE_POSITION_INDEPENDENT_CODE true) endif() # Tell cmake to build dlib as a static library set(BUILD_SHARED_LIBS false) elseif(BUILD_SHARED_LIBS) if (MSVC) message(FATAL_ERROR "Building dlib as a standalone dll is not supported when using Visual Studio. You are highly encouraged to use static linking instead. 
See https://github.com/davisking/dlib/issues/1483 for a discussion.")
   endif()
endif()

# Append -D<option_name> to active_preprocessor_switches, the list of
# preprocessor definitions that is later attached to the dlib target.
macro (enable_preprocessor_switch option_name)
   list(APPEND active_preprocessor_switches "-D${option_name}")
endmacro()

# Remove -D<option_name> from active_preprocessor_switches.  The guard skips
# list(REMOVE_ITEM) while the list is still empty/undefined.
macro (disable_preprocessor_switch option_name)
   if (active_preprocessor_switches)
      list(REMOVE_ITEM active_preprocessor_switches "-D${option_name}")
   endif()
endmacro()

# Enable or disable the -D<option_name> switch according to the current value
# of the variable whose name is option_name.
macro (toggle_preprocessor_switch option_name)
   if (${option_name})
      enable_preprocessor_switch(${option_name})
   else()
      disable_preprocessor_switch(${option_name})
   endif()
endmacro()

# Suppress superfluous ranlib warnings about libdlib.a having no symbols on MacOSX.
# These replace CMake's default static-archive rules, so they must keep CMake's
# rule placeholders (<CMAKE_AR>, <TARGET>, <LINK_FLAGS>, <OBJECTS>, <CMAKE_RANLIB>);
# a rule reduced to just " Scr " has no archiver and no output to create.
# 'S' tells ar not to build the symbol index; ranlib then builds it while
# -no_warning_for_no_symbols silences the empty-object warnings.
if (CMAKE_C_COMPILER_ID STREQUAL "AppleClang")
   set(CMAKE_C_ARCHIVE_CREATE   "<CMAKE_AR> Scr <TARGET> <LINK_FLAGS> <OBJECTS>")
   set(CMAKE_C_ARCHIVE_FINISH   "<CMAKE_RANLIB> -no_warning_for_no_symbols -c <TARGET>")
endif()
if (CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
   set(CMAKE_CXX_ARCHIVE_CREATE "<CMAKE_AR> Scr <TARGET> <LINK_FLAGS> <OBJECTS>")
   set(CMAKE_CXX_ARCHIVE_FINISH "<CMAKE_RANLIB> -no_warning_for_no_symbols -c <TARGET>")
endif()

# Don't try to call add_library(dlib) and setup dlib's stuff if it has already
# been done by some other part of the current cmake project.  We do this
# because it avoids getting warnings/errors about cmake policy CMP0002.  This
# happens when a project tries to call add_subdirectory() on dlib more than
# once.  This most often happens when the top level of a project depends on two
# or more other things which both depend on dlib.
if (NOT TARGET dlib)
   set (DLIB_ISO_CPP_ONLY_STR "Enable this if you don't want to compile any non-ISO C++ code (i.e.
you don't use any of the API Wrappers)" )
   # Help text for the cached build options declared below via option() and
   # via the set(... CACHE ... FORCE) calls that switch features back off.
   set (DLIB_NO_GUI_SUPPORT_STR "Enable this if you don't want to compile any of the dlib GUI code" )
   set (DLIB_ENABLE_STACK_TRACE_STR "Enable this if you want to turn on the DLIB_STACK_TRACE macros" )
   set (DLIB_USE_BLAS_STR "Disable this if you don't want to use a BLAS library" )
   set (DLIB_USE_LAPACK_STR "Disable this if you don't want to use a LAPACK library" )
   set (DLIB_USE_CUDA_STR "Disable this if you don't want to use NVIDIA CUDA" )
   set (DLIB_USE_CUDA_COMPUTE_CAPABILITIES_STR "Set this to a comma-separated list of CUDA compute capabilities" )
   set (DLIB_USE_MKL_SEQUENTIAL_STR "Enable this if you have MKL installed and want to use the sequential version instead of the multi-core version." )
   set (DLIB_USE_MKL_WITH_TBB_STR "Enable this if you have MKL installed and want to use the tbb version instead of the openmp version." )
   set (DLIB_PNG_SUPPORT_STR "Disable this if you don't want to link against libpng" )
   set (DLIB_GIF_SUPPORT_STR "Disable this if you don't want to link against libgif" )
   set (DLIB_JPEG_SUPPORT_STR "Disable this if you don't want to link against libjpeg" )
   set (DLIB_WEBP_SUPPORT_STR "Disable this if you don't want to link against libwebp" )
   set (DLIB_JXL_SUPPORT_STR "Disable this if you don't want to link against libjxl" )
   set (DLIB_LINK_WITH_SQLITE3_STR "Disable this if you don't want to link against sqlite3" )
   # FFTW help string is left disabled, matching the commented-out DLIB_USE_FFTW option.
   #set (DLIB_USE_FFTW_STR "Disable this if you don't want to link against fftw" )
   set (DLIB_USE_MKL_FFT_STR "Disable this if you don't want to use the MKL DFTI FFT implementation" )
   set (DLIB_ENABLE_ASSERTS_STR "Enable this if you want to turn on the DLIB_ASSERT macro" )
   set (DLIB_USE_FFMPEG_STR "Disable this if you don't want to use the FFMPEG library" )

   option(DLIB_ENABLE_ASSERTS ${DLIB_ENABLE_ASSERTS_STR} OFF)
   option(DLIB_ISO_CPP_ONLY ${DLIB_ISO_CPP_ONLY_STR} OFF)
   toggle_preprocessor_switch(DLIB_ISO_CPP_ONLY)
   option(DLIB_NO_GUI_SUPPORT ${DLIB_NO_GUI_SUPPORT_STR} OFF)
toggle_preprocessor_switch(DLIB_NO_GUI_SUPPORT) option(DLIB_ENABLE_STACK_TRACE ${DLIB_ENABLE_STACK_TRACE_STR} OFF) toggle_preprocessor_switch(DLIB_ENABLE_STACK_TRACE) option(DLIB_USE_MKL_SEQUENTIAL ${DLIB_USE_MKL_SEQUENTIAL_STR} OFF) option(DLIB_USE_MKL_WITH_TBB ${DLIB_USE_MKL_WITH_TBB_STR} OFF) if(DLIB_ENABLE_ASSERTS) # Set these variables so they are set in the config.h.in file when dlib # is installed. set (DLIB_DISABLE_ASSERTS false) set (ENABLE_ASSERTS true) enable_preprocessor_switch(ENABLE_ASSERTS) disable_preprocessor_switch(DLIB_DISABLE_ASSERTS) else() # Set these variables so they are set in the config.h.in file when dlib # is installed. set (DLIB_DISABLE_ASSERTS true) set (ENABLE_ASSERTS false) disable_preprocessor_switch(ENABLE_ASSERTS) # Never force the asserts off when doing an in project build. The only # time this matters is when using visual studio. The visual studio IDE # has a drop down that lets the user select either release or debug # builds. The DLIB_ASSERT macro is setup to enable/disable automatically # based on this drop down (via preprocessor magic). However, if # DLIB_DISABLE_ASSERTS is defined it permanently disables asserts no # matter what, which would defeat the visual studio drop down. So here # we make a point to not do that kind of severe disabling when in a # project build. It should also be pointed out that DLIB_DISABLE_ASSERTS # is only needed when building and installing dlib as a separately # installed library. It doesn't matter when doing an in project build. 
if (NOT DLIB_IN_PROJECT_BUILD) enable_preprocessor_switch(DLIB_DISABLE_ASSERTS) endif() endif() if (DLIB_ISO_CPP_ONLY) option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} OFF) option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} OFF) option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} OFF) option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} OFF) option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} OFF) option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} OFF) option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} OFF) option(DLIB_WEBP_SUPPORT ${DLIB_WEBP_SUPPORT_STR} OFF) option(DLIB_JXL_SUPPORT ${DLIB_JXL_SUPPORT_STR} OFF) #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} OFF) option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} OFF) option(DLIB_USE_FFMPEG ${DLIB_USE_FFMPEG_STR} OFF) else() option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} ON) option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} ON) option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} ON) option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} ON) option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} ON) option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} ON) option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} ON) option(DLIB_WEBP_SUPPORT ${DLIB_WEBP_SUPPORT_STR} ON) option(DLIB_JXL_SUPPORT ${DLIB_JXL_SUPPORT_STR} ON) #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} ON) option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} ON) option(DLIB_USE_FFMPEG ${DLIB_USE_FFMPEG_STR} ON) endif() toggle_preprocessor_switch(DLIB_JPEG_SUPPORT) toggle_preprocessor_switch(DLIB_USE_BLAS) toggle_preprocessor_switch(DLIB_USE_LAPACK) toggle_preprocessor_switch(DLIB_USE_CUDA) toggle_preprocessor_switch(DLIB_PNG_SUPPORT) toggle_preprocessor_switch(DLIB_GIF_SUPPORT) toggle_preprocessor_switch(DLIB_WEBP_SUPPORT) toggle_preprocessor_switch(DLIB_JXL_SUPPORT) #toggle_preprocessor_switch(DLIB_USE_FFTW) toggle_preprocessor_switch(DLIB_USE_MKL_FFT) toggle_preprocessor_switch(DLIB_USE_FFMPEG) set(source_files base64/base64_kernel_1.cpp bigint/bigint_kernel_1.cpp bigint/bigint_kernel_2.cpp 
bit_stream/bit_stream_kernel_1.cpp entropy_decoder/entropy_decoder_kernel_1.cpp entropy_decoder/entropy_decoder_kernel_2.cpp entropy_encoder/entropy_encoder_kernel_1.cpp entropy_encoder/entropy_encoder_kernel_2.cpp md5/md5_kernel_1.cpp tokenizer/tokenizer_kernel_1.cpp unicode/unicode.cpp test_for_odr_violations.cpp fft/fft.cpp ) set(dlib_needed_public_libraries) set(dlib_needed_public_includes) set(dlib_needed_public_cflags) set(dlib_needed_public_ldflags) set(dlib_needed_private_libraries) set(dlib_needed_private_includes) if (DLIB_ISO_CPP_ONLY) add_library(dlib ${source_files} ) else() set(source_files ${source_files} sockets/sockets_kernel_1.cpp bsp/bsp.cpp dir_nav/dir_nav_kernel_1.cpp dir_nav/dir_nav_kernel_2.cpp dir_nav/dir_nav_extensions.cpp gui_widgets/fonts.cpp linker/linker_kernel_1.cpp logger/extra_logger_headers.cpp logger/logger_kernel_1.cpp logger/logger_config_file.cpp misc_api/misc_api_kernel_1.cpp misc_api/misc_api_kernel_2.cpp sockets/sockets_extensions.cpp sockets/sockets_kernel_2.cpp sockstreambuf/sockstreambuf.cpp sockstreambuf/sockstreambuf_unbuffered.cpp server/server_kernel.cpp server/server_iostream.cpp server/server_http.cpp threads/multithreaded_object_extension.cpp threads/threaded_object_extension.cpp threads/threads_kernel_1.cpp threads/threads_kernel_2.cpp threads/threads_kernel_shared.cpp threads/thread_pool_extension.cpp threads/async.cpp timer/timer.cpp stack_trace.cpp cuda/cpu_dlib.cpp cuda/tensor_tools.cpp data_io/image_dataset_metadata.cpp data_io/mnist.cpp data_io/cifar.cpp global_optimization/global_function_search.cpp filtering/kalman_filter.cpp svm/auto.cpp ) if(UNIX) set(CMAKE_THREAD_PREFER_PTHREAD ON) find_package(Threads REQUIRED) list (APPEND dlib_needed_private_libraries ${CMAKE_THREAD_LIBS_INIT}) endif() # we want to link to the right stuff depending on our platform. 
if (WIN32 AND NOT CYGWIN) ############################################################################### if (DLIB_NO_GUI_SUPPORT) list (APPEND dlib_needed_private_libraries ws2_32 winmm) else() list (APPEND dlib_needed_private_libraries ws2_32 winmm comctl32 gdi32 imm32) endif() elseif(APPLE) ############################################################################ set(CMAKE_MACOSX_RPATH 1) if (NOT DLIB_NO_GUI_SUPPORT) find_package(X11 QUIET) if (X11_FOUND) # If both X11 and anaconda are installed, it's possible for the # anaconda path to appear before /opt/X11, so we remove anaconda. foreach (ITR ${X11_INCLUDE_DIR}) if ("${ITR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)") list (REMOVE_ITEM X11_INCLUDE_DIR ${ITR}) endif () endforeach(ITR) list (APPEND dlib_needed_public_includes ${X11_INCLUDE_DIR}) list (APPEND dlib_needed_public_libraries ${X11_LIBRARIES}) else() find_library(xlib X11) # Make sure X11 is in the include path. Note that we look for # Xlocale.h rather than Xlib.h because it avoids finding a partial # copy of the X11 headers on systems with anaconda installed. find_path(xlib_path Xlocale.h PATHS /Developer/SDKs/MacOSX10.4u.sdk/usr/X11R6/include /opt/local/include PATH_SUFFIXES X11 ) if (xlib AND xlib_path) get_filename_component(x11_path ${xlib_path} PATH CACHE) list (APPEND dlib_needed_public_includes ${x11_path}) list (APPEND dlib_needed_public_libraries ${xlib}) set(X11_FOUND 1) endif() endif() if (NOT X11_FOUND) message(" *****************************************************************************") message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") message(" *** Make sure XQuartz is installed if you want GUI support. 
***") message(" *** You can download XQuartz from: https://www.xquartz.org/ ***") message(" *****************************************************************************") set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE ) enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) endif() endif() mark_as_advanced(xlib xlib_path x11_path) else () ################################################################################## # link to the socket library if it exists. this is something you need on solaris find_library(socketlib socket) if (socketlib) list (APPEND dlib_needed_private_libraries ${socketlib}) endif () if (NOT DLIB_NO_GUI_SUPPORT) include(FindX11) if (X11_FOUND) list (APPEND dlib_needed_private_includes ${X11_INCLUDE_DIR}) list (APPEND dlib_needed_private_libraries ${X11_LIBRARIES}) else() message(" *****************************************************************************") message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") message(" *** Make sure libx11-dev is installed if you want GUI support. 
***") message(" *** On Ubuntu run: sudo apt-get install libx11-dev ***") message(" *****************************************************************************") set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE ) enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) endif() endif() mark_as_advanced(nsllib socketlib) endif () ################################################################################## if (NOT DLIB_NO_GUI_SUPPORT) set(source_files ${source_files} gui_widgets/widgets.cpp gui_widgets/drawable.cpp gui_widgets/canvas_drawing.cpp gui_widgets/style.cpp gui_widgets/base_widgets.cpp gui_core/gui_core_kernel_1.cpp gui_core/gui_core_kernel_2.cpp ) endif() INCLUDE (CheckFunctionExists) if (DLIB_GIF_SUPPORT) find_package(GIF QUIET) if (GIF_FOUND) list (APPEND dlib_needed_public_includes ${GIF_INCLUDE_DIR}) list (APPEND dlib_needed_public_libraries ${GIF_LIBRARY}) else() set(DLIB_GIF_SUPPORT OFF CACHE STRING ${DLIB_GIF_SUPPORT_STR} FORCE ) toggle_preprocessor_switch(DLIB_GIF_SUPPORT) endif() endif() if (DLIB_PNG_SUPPORT) include(cmake_utils/find_libpng.cmake) if (PNG_FOUND) list (APPEND dlib_needed_private_includes ${PNG_INCLUDE_DIR}) list (APPEND dlib_needed_private_libraries ${PNG_LIBRARIES}) else() # If we can't find libpng then statically compile it in. 
include_directories(external/libpng external/zlib) set(source_files ${source_files} external/libpng/arm/arm_init.c external/libpng/arm/filter_neon_intrinsics.c external/libpng/arm/palette_neon_intrinsics.c external/libpng/png.c external/libpng/pngerror.c external/libpng/pngget.c external/libpng/pngmem.c external/libpng/pngpread.c external/libpng/pngread.c external/libpng/pngrio.c external/libpng/pngrtran.c external/libpng/pngrutil.c external/libpng/pngset.c external/libpng/pngtrans.c external/libpng/pngwio.c external/libpng/pngwrite.c external/libpng/pngwtran.c external/libpng/pngwutil.c external/zlib/adler32.c external/zlib/compress.c external/zlib/crc32.c external/zlib/deflate.c external/zlib/gzclose.c external/zlib/gzlib.c external/zlib/gzread.c external/zlib/gzwrite.c external/zlib/infback.c external/zlib/inffast.c external/zlib/inflate.c external/zlib/inftrees.c external/zlib/trees.c external/zlib/uncompr.c external/zlib/zutil.c ) include(cmake_utils/check_if_neon_available.cmake) if (ARM_NEON_IS_AVAILABLE) message (STATUS "NEON instructions will be used for libpng.") enable_language(ASM) set(source_files ${source_files} external/libpng/arm/arm_init.c external/libpng/arm/filter_neon_intrinsics.c external/libpng/arm/filter_neon.S ) set_source_files_properties(external/libpng/arm/filter_neon.S PROPERTIES COMPILE_FLAGS "${CMAKE_ASM_FLAGS} ${CMAKE_CXX_FLAGS} -x assembler-with-cpp") endif() endif() set(source_files ${source_files} image_loader/png_loader.cpp image_saver/save_png.cpp ) endif() if (DLIB_JPEG_SUPPORT) include(cmake_utils/find_libjpeg.cmake) if (JPEG_FOUND) list (APPEND dlib_needed_private_includes ${JPEG_INCLUDE_DIR}) list (APPEND dlib_needed_private_libraries ${JPEG_LIBRARY}) else() # If we can't find libjpeg then statically compile it in. 
add_definitions(-DDLIB_JPEG_STATIC) set(source_files ${source_files} external/libjpeg/jaricom.c external/libjpeg/jcapimin.c external/libjpeg/jcapistd.c external/libjpeg/jcarith.c external/libjpeg/jccoefct.c external/libjpeg/jccolor.c external/libjpeg/jcdctmgr.c external/libjpeg/jchuff.c external/libjpeg/jcinit.c external/libjpeg/jcmainct.c external/libjpeg/jcmarker.c external/libjpeg/jcmaster.c external/libjpeg/jcomapi.c external/libjpeg/jcparam.c external/libjpeg/jcprepct.c external/libjpeg/jcsample.c external/libjpeg/jdapimin.c external/libjpeg/jdapistd.c external/libjpeg/jdarith.c external/libjpeg/jdatadst.c external/libjpeg/jdatasrc.c external/libjpeg/jdcoefct.c external/libjpeg/jdcolor.c external/libjpeg/jddctmgr.c external/libjpeg/jdhuff.c external/libjpeg/jdinput.c external/libjpeg/jdmainct.c external/libjpeg/jdmarker.c external/libjpeg/jdmaster.c external/libjpeg/jdmerge.c external/libjpeg/jdpostct.c external/libjpeg/jdsample.c external/libjpeg/jerror.c external/libjpeg/jfdctflt.c external/libjpeg/jfdctfst.c external/libjpeg/jfdctint.c external/libjpeg/jidctflt.c external/libjpeg/jidctfst.c external/libjpeg/jidctint.c external/libjpeg/jmemmgr.c external/libjpeg/jmemnobs.c external/libjpeg/jquant1.c external/libjpeg/jquant2.c external/libjpeg/jutils.c ) endif() set(source_files ${source_files} image_loader/jpeg_loader.cpp image_saver/save_jpeg.cpp ) endif() if (DLIB_WEBP_SUPPORT) include(cmake_utils/find_libwebp.cmake) if (WEBP_FOUND) list (APPEND dlib_needed_private_includes ${WEBP_INCLUDE_DIR}) list (APPEND dlib_needed_private_libraries ${WEBP_LIBRARY}) set(source_files ${source_files} image_loader/webp_loader.cpp image_saver/save_webp.cpp ) else() set(DLIB_WEBP_SUPPORT OFF CACHE BOOL ${DLIB_WEBP_SUPPORT_STR} FORCE ) toggle_preprocessor_switch(DLIB_WEBP_SUPPORT) endif() endif() if (DLIB_JXL_SUPPORT) include(cmake_utils/find_libjxl.cmake) if (JXL_FOUND) list (APPEND dlib_needed_private_includes ${JXL_INCLUDE_DIRS}) list (APPEND dlib_needed_private_libraries 
${JXL_LIBRARIES}) list (APPEND dlib_needed_public_cflags ${JXL_CFLAGS}) list (APPEND dlib_needed_public_ldflags ${JXL_LDFLAGS}) set(source_files ${source_files} image_loader/jxl_loader.cpp image_saver/save_jxl.cpp ) enable_preprocessor_switch(DLIB_JXL_SUPPORT) else() set(DLIB_JXL_SUPPORT OFF CACHE BOOL ${DLIB_JXL_SUPPORT_STR} FORCE) disable_preprocessor_switch(DLIB_JXL_SUPPORT) endif() endif() if (DLIB_USE_BLAS OR DLIB_USE_LAPACK OR DLIB_USE_MKL_FFT) if (DLIB_USE_MKL_WITH_TBB AND DLIB_USE_MKL_SEQUENTIAL) set(DLIB_USE_MKL_SEQUENTIAL OFF CACHE STRING ${DLIB_USE_MKL_SEQUENTIAL_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_MKL_SEQUENTIAL) message(STATUS "Disabling DLIB_USE_MKL_SEQUENTIAL. It cannot be used simultaneously with DLIB_USE_MKL_WITH_TBB.") endif() # Try to find BLAS, LAPACK and MKL include(cmake_utils/find_blas.cmake) if (DLIB_USE_BLAS) if (blas_found) list (APPEND dlib_needed_public_libraries ${blas_libraries}) else() set(DLIB_USE_BLAS OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_BLAS) endif() endif() if (DLIB_USE_LAPACK) if (lapack_found) list (APPEND dlib_needed_public_libraries ${lapack_libraries}) if (lapack_with_underscore) set(LAPACK_FORCE_UNDERSCORE 1) enable_preprocessor_switch(LAPACK_FORCE_UNDERSCORE) elseif (lapack_without_underscore) set(LAPACK_FORCE_NOUNDERSCORE 1) enable_preprocessor_switch(LAPACK_FORCE_NOUNDERSCORE) endif () else() set(DLIB_USE_LAPACK OFF CACHE STRING ${DLIB_USE_LAPACK_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_LAPACK) endif() endif() if (DLIB_USE_MKL_FFT) if (found_intel_mkl AND found_intel_mkl_headers) list (APPEND dlib_needed_public_includes ${mkl_include_dir}) list (APPEND dlib_needed_public_libraries ${mkl_libraries}) else() set(DLIB_USE_MKL_FFT OFF CACHE STRING ${DLIB_USE_MKL_FFT_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_MKL_FFT) endif() endif() endif() if (DLIB_USE_CUDA) find_package(CUDAToolkit) if (CUDAToolkit_FOUND AND CUDAToolkit_NVCC_EXECUTABLE) 
set(CMAKE_CUDA_COMPILER ${CUDAToolkit_NVCC_EXECUTABLE}) # Set USER_DID_NOT_SPECIFY_WHAT_CUDA_ARCH_TO_USE before calling # enable_language(CUDA) because enable_language() sets # CMAKE_CUDA_ARCHITECTURES to a default that isn't especially # helpful for most users in newer cmake versions. E.g. it picks # the oldest supported arch the cuda toolkit you have can build for # which is often so old your GPU can't actually run the resulting # kernels. set(USER_DID_NOT_SPECIFY_WHAT_CUDA_ARCH_TO_USE FALSE) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND (NOT DEFINED ENV{CUDAARCHS} OR "$ENV{CUDAARCHS}" STREQUAL "")) set(USER_DID_NOT_SPECIFY_WHAT_CUDA_ARCH_TO_USE TRUE) endif() enable_language(CUDA) # If the user didn't say what cuda arch they want to use try to pick something reasonable if(USER_DID_NOT_SPECIFY_WHAT_CUDA_ARCH_TO_USE) if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.24) # Auto-detect host GPU(s); safest default on modern CMake. set(CMAKE_CUDA_ARCHITECTURES native) # requires CMake ≥ 3.24 else() # Fallback by nvcc version to avoid asking for archs it doesn't know yet if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11) # CUDA 10.x and older (sm_75 needs CUDA 10.0+; split this branch further if older toolkits must be supported) set(CMAKE_CUDA_ARCHITECTURES 52;60;61;70;75) elseif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.1) # CUDA 11.0: nvcc only learned sm_86 in CUDA 11.1, so stop at sm_80 set(CMAKE_CUDA_ARCHITECTURES 60;61;70;75;80) elseif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) # CUDA 11.1 - 11.x set(CMAKE_CUDA_ARCHITECTURES 60;61;70;75;80;86) else() # CUDA 12.x+ # (Keep this conservative; add 89/90 only on toolkits that support them.) 
set(CMAKE_CUDA_ARCHITECTURES 70;75;80;86;89;90) endif() endif() endif() find_package(CUDNN) if(CUDNN_FOUND) set(source_files ${source_files} cuda/cuda_dlib.cu cuda/cudnn_dlibapi.cpp cuda/cublas_dlibapi.cpp cuda/cusolver_dlibapi.cu cuda/curand_dlibapi.cpp cuda/cuda_data_ptr.cpp cuda/gpu_data.cpp ) list (APPEND dlib_needed_private_libraries CUDA::cublas) list (APPEND dlib_needed_private_libraries ${CUDNN_LIBRARY_PATH}) list (APPEND dlib_needed_private_libraries CUDA::curand) list (APPEND dlib_needed_private_libraries CUDA::cusolver) list (APPEND dlib_needed_private_libraries CUDA::cudart) if(openmp_libraries) list (APPEND dlib_needed_private_libraries ${openmp_libraries}) endif() include_directories(${CUDAToolkit_INCLUDE_DIRS} ${CUDNN_INCLUDE_PATH}) message(STATUS "Enabling CUDA support for dlib. DLIB WILL USE CUDA using cuda arch ${CMAKE_CUDA_ARCHITECTURES}. If you don't want to use that arch set the CUDAARCHS env var or CMAKE_CUDA_ARCHITECTURES cmake variable.") else() #[[ Turn the option off but keep the CUDA option's own help string (this previously reused DLIB_USE_BLAS_STR). ]] set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_CUDA_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_CUDA) message(STATUS "DID NOT FIND CUDNN") message(STATUS "Disabling CUDA support for dlib. DLIB WILL NOT USE CUDA") endif() else() set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_CUDA_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_CUDA) #[[ find_package(CUDAToolkit) reports through CUDAToolkit_FOUND; CUDA_FOUND belongs to the legacy FindCUDA module and is not set by it. Distinguish "no toolkit" from "toolkit without nvcc". ]] if (NOT CUDAToolkit_FOUND) message(STATUS "DID NOT FIND CUDA") elseif (NOT CUDAToolkit_NVCC_EXECUTABLE) message(STATUS "DID NOT FIND NVCC") endif() message(STATUS "Disabling CUDA support for dlib. 
DLIB WILL NOT USE CUDA") endif() endif() if (DLIB_LINK_WITH_SQLITE3) find_library(sqlite sqlite3) # make sure sqlite3.h is in the include path find_path(sqlite_path sqlite3.h) if (sqlite AND sqlite_path) list (APPEND dlib_needed_public_includes ${sqlite_path}) list (APPEND dlib_needed_public_libraries ${sqlite} ) else() set(DLIB_LINK_WITH_SQLITE3 OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE ) endif() mark_as_advanced(sqlite sqlite_path) endif() if (DLIB_USE_FFTW) find_library(fftw fftw3) # make sure fftw3.h is in the include path find_path(fftw_path fftw3.h) if (fftw AND fftw_path) list (APPEND dlib_needed_private_includes ${fftw_path}) list (APPEND dlib_needed_private_libraries ${fftw}) else() set(DLIB_USE_FFTW OFF CACHE STRING ${DLIB_USE_FFTW_STR} FORCE ) toggle_preprocessor_switch(DLIB_USE_FFTW) endif() mark_as_advanced(fftw fftw_path) endif() if (DLIB_USE_FFMPEG) include(cmake_utils/find_ffmpeg.cmake) if (FFMPEG_FOUND) list (APPEND dlib_needed_public_includes ${FFMPEG_INCLUDE_DIRS}) list (APPEND dlib_needed_public_libraries ${FFMPEG_LINK_LIBRARIES}) list (APPEND dlib_needed_public_cflags ${FFMPEG_CFLAGS}) list (APPEND dlib_needed_public_ldflags ${FFMPEG_LDFLAGS}) enable_preprocessor_switch(DLIB_USE_FFMPEG) else() set(DLIB_USE_FFMPEG OFF CACHE BOOL ${DLIB_USE_FFMPEG_STR} FORCE ) disable_preprocessor_switch(DLIB_USE_FFMPEG) endif() endif() add_library(dlib ${source_files}) endif () ##### end of if NOT DLIB_ISO_CPP_ONLY ########################################################## target_include_directories(dlib INTERFACE $ INTERFACE $ PUBLIC ${dlib_needed_public_includes} PRIVATE ${dlib_needed_private_includes} ) target_link_libraries(dlib PUBLIC ${dlib_needed_public_libraries} ${dlib_needed_public_ldflags}) target_link_libraries(dlib PRIVATE ${dlib_needed_private_libraries}) target_compile_options(dlib PUBLIC ${dlib_needed_public_cflags}) if (DLIB_IN_PROJECT_BUILD) target_compile_options(dlib PUBLIC ${active_preprocessor_switches}) else() # These are 
private in this case because they will be controlled by the # contents of dlib/config.h once it's installed. But for in project # builds, there is no real config.h so they are public in the above case. target_compile_options(dlib PRIVATE ${active_preprocessor_switches}) # Do this so that dlib/config.h won't set DLIB_NOT_CONFIGURED. This will then allow # the code in dlib/threads_kernel_shared.cpp to emit a linker error for users who # don't use the configured config.h file generated by cmake. target_compile_options(dlib PRIVATE -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) # Do this so that dlib/config.h can record the version of dlib it's configured with # and ultimately issue a linker error to people who try to use a binary dlib that is # the wrong version. set(DLIB_CHECK_FOR_VERSION_MISMATCH DLIB_VERSION_MISMATCH_CHECK__EXPECTED_VERSION_${CPACK_PACKAGE_VERSION_MAJOR}_${CPACK_PACKAGE_VERSION_MINOR}_${CPACK_PACKAGE_VERSION_PATCH}) target_compile_options(dlib PRIVATE "-DDLIB_CHECK_FOR_VERSION_MISMATCH=${DLIB_CHECK_FOR_VERSION_MISMATCH}") endif() # Allow the unit tests to ask us to compile the all/source.cpp file just to make sure it compiles. 
if (DLIB_TEST_COMPILE_ALL_SOURCE_CPP) add_library(dlib_all_source_cpp STATIC all/source.cpp) target_link_libraries(dlib_all_source_cpp dlib) target_compile_options(dlib_all_source_cpp PUBLIC ${active_preprocessor_switches}) target_compile_features(dlib_all_source_cpp PUBLIC cxx_std_14) endif() target_compile_features(dlib PUBLIC cxx_std_14) if((MSVC AND CMAKE_VERSION VERSION_LESS 3.11)) target_compile_options(dlib PUBLIC ${active_compile_opts}) target_compile_options(dlib PRIVATE ${active_compile_opts_private}) else() target_compile_options(dlib PUBLIC $<$:${active_compile_opts}>) target_compile_options(dlib PRIVATE $<$:${active_compile_opts_private}>) endif() # NOTE(review): the `$<$:...>` arguments above look like generator expressions whose condition was lost when this copy was extracted; confirm against the real CMakeLists.txt. # Install the library (skipped for in-project builds, i.e. when DLIB_IN_PROJECT_BUILD is set) if (NOT DLIB_IN_PROJECT_BUILD) # Space-separated copy of the public link libraries, presumably substituted into cmake_utils/dlib.pc.in string (REPLACE ";" " " pkg_config_dlib_needed_libraries "${dlib_needed_public_libraries}") # Make the -I include options for pkg-config foreach (ITR ${dlib_needed_public_includes}) set (pkg_config_dlib_needed_includes "${pkg_config_dlib_needed_includes} -I${ITR}") endforeach() set_target_properties(dlib PROPERTIES VERSION ${VERSION}) install(TARGETS dlib EXPORT dlib RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # Windows considers .dll to be runtime artifacts LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) # Install headers: *.h, *.cmake, *_tutorial.txt and the extension-less files listed (cassert, cstring, ...), excluding anything under the build directory. 
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib FILES_MATCHING PATTERN "*.h" PATTERN "*.cmake" PATTERN "*_tutorial.txt" PATTERN "cassert" PATTERN "cstring" PATTERN "fstream" PATTERN "iomanip" PATTERN "iosfwd" PATTERN "iostream" PATTERN "istream" PATTERN "locale" PATTERN "ostream" PATTERN "sstream" REGEX "${CMAKE_CURRENT_BINARY_DIR}" EXCLUDE) configure_file(${PROJECT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/config.h) # overwrite config.h with the configured one install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) configure_file(${PROJECT_SOURCE_DIR}/revision.h.in ${CMAKE_CURRENT_BINARY_DIR}/revision.h) install(FILES 
${CMAKE_CURRENT_BINARY_DIR}/revision.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) ## Config.cmake generation and installation set(ConfigPackageLocation "${CMAKE_INSTALL_LIBDIR}/cmake/dlib") # Exported targets get the dlib:: prefix, matching the dlib::dlib alias added at the end of this file. install(EXPORT dlib NAMESPACE dlib:: DESTINATION ${ConfigPackageLocation}) configure_file(cmake_utils/dlibConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" @ONLY) include(CMakePackageConfigHelpers) # AnyNewerVersion: a requested version V is satisfied by any installed version >= V. write_basic_package_version_file( "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" VERSION ${VERSION} COMPATIBILITY AnyNewerVersion ) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" DESTINATION ${ConfigPackageLocation}) ## dlib-1.pc generation and installation configure_file("cmake_utils/dlib.pc.in" "dlib-1.pc" @ONLY) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/dlib-1.pc" DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") # Add a cpack "package" target. This will create an archive containing # the built library file, the header files, and cmake and pkgconfig # configuration files. include(CPack) endif() endif() if (MSVC) # Give the output library files names that are unique functions of the # visual studio mode that compiled them. We do this so that people who # compile dlib and then copy the .lib files around (which they shouldn't be # doing in the first place!) will hopefully be slightly less confused by # what happens since, at the very least, the filenames will indicate what # visual studio runtime they go with. 
math(EXPR numbits ${CMAKE_SIZEOF_VOID_P}*8) set_target_properties(dlib PROPERTIES DEBUG_POSTFIX "${VERSION}_debug_${numbits}bit_msvc${MSVC_VERSION}") set_target_properties(dlib PROPERTIES RELEASE_POSTFIX "${VERSION}_release_${numbits}bit_msvc${MSVC_VERSION}") set_target_properties(dlib PROPERTIES MINSIZEREL_POSTFIX "${VERSION}_minsizerel_${numbits}bit_msvc${MSVC_VERSION}") set_target_properties(dlib PROPERTIES RELWITHDEBINFO_POSTFIX "${VERSION}_relwithdebinfo_${numbits}bit_msvc${MSVC_VERSION}") endif() # Check if we are being built as part of a pybind11 module. if (COMMAND pybind11_add_module) # Don't export unnecessary symbols. set_target_properties(dlib PROPERTIES CXX_VISIBILITY_PRESET "hidden") set_target_properties(dlib PROPERTIES CUDA_VISIBILITY_PRESET "hidden") endif() if (WIN32 AND mkl_iomp_dll) # If we are using the Intel MKL on windows then try and copy the iomp dll # file to the output folder. We do this since a very large number of # windows users don't understand that they need to add the Intel MKL's # folders to their PATH to use the Intel MKL. They then complain on the # dlib forums. Copying the Intel MKL dlls to the output directory removes # the need to add the Intel MKL to the PATH. if (CMAKE_LIBRARY_OUTPUT_DIRECTORY) add_custom_command(TARGET dlib POST_BUILD # In some newer versions of windows/visual studio the output Config folder doesn't # exist at first, so you can't copy to it unless you make it yourself. So make # sure the target folder exists first. COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/" COMMAND ${CMAKE_COMMAND} -E copy "${mkl_iomp_dll}" "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/" ) else() # NOTE(review): the `$/` path segments below look like a generator expression (presumably the build configuration directory) whose body was lost when this copy was extracted; confirm against the real file. 
add_custom_command(TARGET dlib POST_BUILD # In some newer versions of windows/visual studio the output Config folder doesn't # exist at first, so you can't copy to it unless you make it yourself. So make # sure the target folder exists first. 
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/$/" COMMAND ${CMAKE_COMMAND} -E copy "${mkl_iomp_dll}" "${CMAKE_BINARY_DIR}/$/" ) endif() endif() add_library(dlib::dlib ALIAS dlib) ================================================ FILE: dlib/LICENSE.txt ================================================ Boost Software License - Version 1.0 - August 17th, 2003 Permission is hereby granted, free of charge, to any person or organization obtaining a copy of the software and accompanying documentation covered by this license (the "Software") to use, reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative works of the Software, and to permit third-parties to whom the Software is furnished to do so, all subject to the following: The copyright notices in the Software and this entire statement, including the above license grant, this restriction and the following disclaimer, must be included in all copies of the Software, in whole or in part, and all derivative works of the Software, unless such copies or derivative works are solely in the form of machine-executable object code generated by a source language processor. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: dlib/algs.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifdef DLIB_ALL_SOURCE_END #include "dlib_basic_cpp_build_tutorial.txt" #endif #ifndef DLIB_ALGs_ #define DLIB_ALGs_ // this file contains miscellaneous stuff // Give people who forget the -std=c++14 option a reminder #if (defined(__GNUC__) && ((__GNUC__ >= 5 && __GNUC_MINOR__ >= 0) || (__GNUC__ > 5))) || \ (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3))) #if __cplusplus < 201402L #error "Dlib requires C++14 support. Give your compiler the -std=c++14 option to enable it." #endif #endif /* NOTE(review): both alternatives of the version test above collapse to a single comparison. `__GNUC_MINOR__ >= 0` is always true, so the GCC clause is just `__GNUC__ >= 5`; and `(__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3)` equals `__clang_major__ >= 3`, so the C++14 check already applies from clang 3.0 rather than 3.4. `__clang_major__ > 3` was probably intended in the second alternative - confirm before changing. */ #if defined __NVCC__ // Disable the "statement is unreachable" message since it will go off on code that is // actually reachable but just happens to not be reachable sometimes during certain // template instantiations. #ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ #pragma nv_diag_suppress code_is_unreachable #else #pragma diag_suppress code_is_unreachable #endif #endif #ifdef _MSC_VER #if _MSC_VER < 1900 #error "dlib versions newer than v19.1 use C++11 and therefore require Visual Studio 2015 or newer." #endif /* NOTE(review): this message still says C++11, while the check near the top of this file requires C++14; the wording looks stale. */ // Disable the following warnings for Visual Studio // this is to disable the "'this' : used in base member initializer list" // warning you get from some of the GUI objects since all the objects // require that their parent class be passed into their constructor. // In this case though it is totally safe so it is ok to disable this warning. #pragma warning(disable : 4355) // This is a warning you get sometimes when Visual Studio performs a Koenig Lookup. // This is a bug in visual studio. It is a totally legitimate thing to // expect from a compiler. #pragma warning(disable : 4675) // This is a warning you get from visual studio 2005 about things in the standard C++ // library being "deprecated." I checked the C++ standard and it doesn't say jack // about any of them (I checked the searchable PDF). So this warning is total Bunk. 
#pragma warning(disable : 4996) // This is a warning you get from visual studio 2003: // warning C4345: behavior change: an object of POD type constructed with an initializer // of the form () will be default-initialized. // I love it when this compiler gives warnings about bugs in previous versions of itself. #pragma warning(disable : 4345) // Disable warnings about conversion from size_t to unsigned long and long. #pragma warning(disable : 4267) // Disable warnings about conversion from double to float #pragma warning(disable : 4244) #pragma warning(disable : 4305) // Disable "warning C4180: qualifier applied to function type has no meaning; ignored". // This warning happens often in generic code that works with functions and isn't useful. #pragma warning(disable : 4180) // Disable "warning C4290: C++ exception specification ignored except to indicate a function is not __declspec(nothrow)" #pragma warning(disable : 4290) // DNN module uses template-based network declaration that leads to very long // type names. Visual Studio will produce Warning C4503 in such cases. https://msdn.microsoft.com/en-us/library/074af4b6.aspx says // that correct binaries are still produced even when this warning happens, but linker errors from visual studio, if they occur, could be confusing. #pragma warning( disable: 4503 ) #endif #ifdef __BORLANDC__ // Disable the following warnings for the Borland Compilers // // These warnings just say that the compiler is refusing to inline functions with // loops or try blocks in them. 
// #pragma option -w-8027 #pragma option -w-8026 #endif /* NOTE(review): in this copy of the file the header names of the bare #include lines below, and many other <...> spans (template parameter and argument lists), appear to have been stripped; consult the original header before relying on the exact declarations. */ #include // for the exceptions #ifdef __CYGWIN__ namespace std { typedef std::basic_string wstring; } #endif #include "platform.h" #include "windows_magic.h" #include // for std::swap #include // for std::bad_alloc #include #include #include #include // for std::isfinite for is_finite() #include "assert.h" #include "error.h" #include "noncopyable.h" #include "enable_if.h" #include "uintn.h" #include "numeric_constants.h" #include "memory_manager_stateless/memory_manager_stateless_kernel_1.h" // for the default memory manager #include "type_traits.h" // ---------------------------------------------------------------------------------------- /*!A _dT !*/ template inline charT _dTcast (const char a, const wchar_t b); template <> inline char _dTcast (const char a, const wchar_t ) { return a; } template <> inline wchar_t _dTcast (const char , const wchar_t b) { return b; } template inline const charT* _dTcast ( const char* a, const wchar_t* b); template <> inline const char* _dTcast ( const char* a, const wchar_t* ) { return a; } template <> inline const wchar_t* _dTcast ( const char* , const wchar_t* b) { return b; } #define _dT(charT,str) _dTcast(str,L##str) /*! requires - charT == char or wchar_t - str == a string or character literal ensures - returns the literal in the form of a charT type literal. !*/ // ---------------------------------------------------------------------------------------- namespace dlib { // ---------------------------------------------------------------------------------------- /*!A default_memory_manager This memory manager just calls new and delete directly. 
!*/ typedef memory_manager_stateless_kernel_1 default_memory_manager; // ---------------------------------------------------------------------------------------- /*!A swap !*/ // make swap available in the dlib namespace using std::swap; // ---------------------------------------------------------------------------------------- /*! 
Here is where I define my return codes. It is important that they all be < 0. !*/ enum general_return_codes { TIMEOUT = -1, WOULDBLOCK = -2, OTHER_ERROR = -3, SHUTDOWN = -4, PORTINUSE = -5 }; // ---------------------------------------------------------------------------------------- inline unsigned long square_root ( unsigned long value ) /*! requires - value <= 2^32 - 1 ensures - returns the square root of value. if the square root is not an integer then it will be rounded up to the nearest integer. !*/ { unsigned long x; // set the initial guess for what the root is depending on // how big value is /* 0, 1 and 2 are returned unchanged: each already equals its rounded-up square root */ if (value < 3) return value; else if (value < 4096) // 12 x = 45; else if (value < 65536) // 16 x = 179; else if (value < 1048576) // 20 x = 717; else if (value < 16777216) // 24 x = 2867; else if (value < 268435456) // 28 x = 11469; else // 32 x = 45875; // find the root /* four Newton (Heron) steps x = (x + value/x) / 2 starting from the guess above */ x = (x + value/x)>>1; x = (x + value/x)>>1; x = (x + value/x)>>1; x = (x + value/x)>>1; /* round up when x is still below the true root */ if (x*x < value) return x+1; else return x; } // ---------------------------------------------------------------------------------------- template < typename T > void median ( T& one, T& two, T& three ); /*! requires - T implements operator< - T is swappable by a global swap() ensures - #one is the median - #one, #two, and #three are some permutation of one, two, and three. 
!*/ template < typename T > void median ( T& one, T& two, T& three ) { using std::swap; using dlib::swap; if ( one < two ) { // one < two if ( two < three ) { // one < two < three : two swap(one,two); } else { // one < two >= three if ( one < three) { // three swap(three,one); } /* otherwise three <= one < two: one is already the median */ } } else { // one >= two if ( three < one ) { // three <= one >= two if ( three < two ) { // two swap(two,one); } else { // three swap(three,one); } } /* otherwise two <= one <= three: one is already the median */ } } // ---------------------------------------------------------------------------------------- /* Generic >, !=, <= and >= for any pair of argument types, defined in terms of their operator< and operator==. */ namespace relational_operators { template < typename A, typename B > constexpr bool operator> ( const A& a, const B& b ) { return b < a; } // --------------------------------- template < typename A, typename B > constexpr bool operator!= ( const A& a, const B& b ) { return !(a == b); } // --------------------------------- template < typename A, typename B > constexpr bool operator<= ( const A& a, const B& b ) { return !(b < a); } // --------------------------------- template < typename A, typename B > constexpr bool operator>= ( const A& a, const B& b ) { return !(a < b); } } // ---------------------------------------------------------------------------------------- template < typename T > void exchange ( T& a, T& b ) /*! This function does the exact same thing that global swap does and it does it by just calling swap. But a lot of compilers have problems doing a Koenig Lookup and the fact that this has a different name (global swap has the same name as the member functions called swap) makes them compile right. So this is a workaround but not too ugly of one. But hopefully I can get rid of this in a few years. So this function is already deprecated. 
This also means you should NOT use this function in your own code unless you have to support an old buggy compiler that benefits from this hack. 
!*/ { using std::swap; using dlib::swap; swap(a,b); } // ---------------------------------------------------------------------------------------- /* Tag types: special_ derives from general_, so a special_ argument binds to either; presumably passed to rank overloads (an exact special_ match beats the general_ fallback) - confirm at call sites. */ struct general_ {}; struct special_ : general_ {}; template struct int_ { typedef int type; }; // ---------------------------------------------------------------------------------------- /*!A is_same_object This is a templated function which checks if both of its arguments are actually references to the same object. It returns true if they are and false otherwise. !*/ // handle the case where T and U are unrelated types. template < typename T, typename U > std::enable_if_t::value && !std::is_convertible::value, bool> is_same_object ( const T& a, const U& b ) { return ((void*)&a == (void*)&b); } // handle the case where T and U are related types because their pointers can be // implicitly converted into one or the other. E.g. a derived class and its base class. // Or where both T and U are just the same type. This way we make sure that if there is a // valid way to convert between these two pointer types then we will take that route rather // than the void* approach used otherwise. template < typename T, typename U > std::enable_if_t::value || std::is_convertible::value, bool> is_same_object ( const T& a, const U& b ) { return (&a == &b); } // ---------------------------------------------------------------------------------------- /* Function object that copies source into destination by assignment. */ template < typename T > class copy_functor { public: void operator() ( const T& source, T& destination ) const { destination = source; } }; // ---------------------------------------------------------------------------------------- /*!A static_switch To use this template you give it some number of boolean expressions and it tells you which one of them is true. 
If more than one of them is true then it causes a compile time error. 
for example: static_switch<1 + 1 == 2, 4 - 1 == 4>::value == 1 // because the first expression is true static_switch<1 + 1 == 3, 4 == 4>::value == 2 // because the second expression is true static_switch<1 + 1 == 3, 4 == 5>::value == 0 // 0 here because none of them are true static_switch<1 + 1 == 2, 4 == 4>::value == compiler error // because more than one expression is true At most 15 expressions are supported. The primary template below is only declared, never defined, so any combination with more than one true expression has no matching specialization and fails to compile. !*/ template < bool v1 = 0, bool v2 = 0, bool v3 = 0, bool v4 = 0, bool v5 = 0, bool v6 = 0, bool v7 = 0, bool v8 = 0, bool v9 = 0, bool v10 = 0, bool v11 = 0, bool v12 = 0, bool v13 = 0, bool v14 = 0, bool v15 = 0 > struct static_switch; template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 0; }; template <> struct static_switch<1,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 1; }; template <> struct static_switch<0,1,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 2; }; template <> struct static_switch<0,0,1,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 3; }; template <> struct static_switch<0,0,0,1,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 4; }; template <> struct static_switch<0,0,0,0,1,0,0,0,0,0,0,0,0,0,0> { const static int value = 5; }; template <> struct static_switch<0,0,0,0,0,1,0,0,0,0,0,0,0,0,0> { const static int value = 6; }; template <> struct static_switch<0,0,0,0,0,0,1,0,0,0,0,0,0,0,0> { const static int value = 7; }; template <> struct static_switch<0,0,0,0,0,0,0,1,0,0,0,0,0,0,0> { const static int value = 8; }; template <> struct static_switch<0,0,0,0,0,0,0,0,1,0,0,0,0,0,0> { const static int value = 9; }; template <> struct static_switch<0,0,0,0,0,0,0,0,0,1,0,0,0,0,0> { const static int value = 10; }; template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,1,0,0,0,0> { const static int value = 11; }; template <> struct 
static_switch<0,0,0,0,0,0,0,0,0,0,0,1,0,0,0> { const static int value = 12; }; template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,1,0,0> { const static int value = 13; }; template <> struct 
static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,1,0> { const static int value = 14; }; template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,1> { const static int value = 15; }; // ---------------------------------------------------------------------------------------- template std::enable_if_t::value, bool> is_finite(T value) /*! requires - value must be some kind of scalar type such as int or double ensures - returns true if value is a finite value (e.g. not infinity or NaN) and false otherwise. !*/ { return std::isfinite(value); } /* Second overload (its enable_if condition appears stripped in this copy): converts the value to double before calling std::isfinite. */ template std::enable_if_t::value, bool> is_finite(T value) { return std::isfinite((double)value); } // ---------------------------------------------------------------------------------------- /*!A promote This is a template that takes one of the built in scalar types and gives you another scalar type that should be big enough to hold sums of values from the original scalar type. The new scalar type will also always be signed. For example, promote::type == int32 !*/ template struct promote; template struct promote { typedef int32 type; }; template struct promote { typedef int32 type; }; template struct promote { typedef int64 type; }; template struct promote { typedef int64 type; }; template <> struct promote { typedef double type; }; template <> struct promote { typedef double type; }; template <> struct promote { typedef long double type; }; // ---------------------------------------------------------------------------------------- /*!A assign_zero_if_built_in_scalar_type This function assigns its argument the value of 0 if it is a built in scalar type according to the is_built_in_scalar_type<> template. If it isn't a built in scalar type then it does nothing. 
!*/ template inline typename disable_if,void>::type assign_zero_if_built_in_scalar_type (T&){} template inline typename enable_if,void>::type assign_zero_if_built_in_scalar_type (T& a){a=0;} // ---------------------------------------------------------------------------------------- template T put_in_range ( const T& a, const T& b, const T& val ) /*! requires - T is a type that looks like double, float, int, or so forth ensures - if (val is within the range [a,b]) then - returns val - else - returns the end of the range [a,b] that is closest to val - a and b may be given in either order: when b < a the range is treated as [b,a] !*/ { if (a < b) { if (val < a) return a; else if (val > b) return b; } else { if (val < b) return b; else if (val > a) return a; } return val; } /* NOTE(review): as shown here the overload below calls put_in_range(a,b,val) with no explicit template argument. Overload resolution would then pick this same non-template overload and recurse forever, so the original presumably reads put_in_range<double>(a,b,val) and the argument list was lost in extraction - confirm. */ // overload for double inline double put_in_range(const double& a, const double& b, const double& val) { return put_in_range(a,b,val); } // ---------------------------------------------------------------------------------------- /*!A tabs This is a template to compute the absolute value of a number at compile time. 
For example, abs<-4>::value == 4 abs<4>::value == 4 !*/ /* NOTE(review): the examples above spell the name abs<...>; the template is tabs. */ template struct tabs { const static long value = x; }; template struct tabs::type> { const static long value = -x; }; // ---------------------------------------------------------------------------------------- /*!A tmax This is a template to compute the max of two values at compile time. For example, tmax<4,7>::value == 7 !*/ template struct tmax { const static long value = x; }; template struct tmax x)>::type> { const static long value = y; }; // ---------------------------------------------------------------------------------------- /*!A tmin This is a template to compute the min of two values at compile time. For example, tmin<4,7>::value == 4 !*/ template struct tmin { const static long value = x; }; template struct tmin::type> { const static long value = y; }; // ---------------------------------------------------------------------------------------- #define DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(testname, returnT, funct_name, args) \ struct _two_bytes_##testname { char a[2]; }; \ template < typename T, returnT (T::*funct)args > \ struct _helper_##testname { typedef char type; }; \ template \ static char _has_##testname##_helper( typename _helper_##testname::type ) { return 0;} \ template \ static _two_bytes_##testname _has_##testname##_helper(int) { return _two_bytes_##testname();} \ template struct _##testname##workaroundbug { \ const static unsigned long U = sizeof(_has_##testname##_helper('a')); }; \ template ::U > \ struct testname { static const bool value = false; }; \ template \ struct testname { static const bool value = true; }; /*!A DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST The DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST() macro is used to define traits templates that tell you if a class has a certain member function. 
For example, to make a test to see if a class has a public method with the signature void print(int) you would say: DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, (int)) Then you can check if a class, T, has this method by looking at the boolean value: has_print::value which will be true if the member function is in the T class. Note that you can test for member functions taking no arguments by simply passing in empty () like so: DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()) This would test for a member of the form: void print(). To test for const member functions you would use a statement such as this: DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()const) This would test for a member of the form: void print() const. To test for const templated member functions you would use a statement such as this: DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, template print, ()) This would test for a member of the form: template void print(). !*/ /* NOTE(review): the last example above tests a non-const templated member (no trailing const), so the word "const" in its lead-in sentence looks like a typo. */ // ---------------------------------------------------------------------------------------- /* funct_wrap0 ... funct_wrap5: function objects that keep a reference to a free function taking 0 to 5 arguments and forward operator() calls to it; created by wrap_function() below. 
*/ template class funct_wrap0 { public: funct_wrap0(T (&f_)()):f(f_){} T operator()() const { return f(); } private: T (&f)(); }; template class funct_wrap1 { public: funct_wrap1(T (&f_)(A0)):f(f_){} T operator()(A0 a0) const { return f(a0); } private: T (&f)(A0); }; template class funct_wrap2 { public: funct_wrap2(T (&f_)(A0,A1)):f(f_){} T operator()(A0 a0, A1 a1) const { return f(a0,a1); } private: T (&f)(A0,A1); }; template class funct_wrap3 { public: funct_wrap3(T (&f_)(A0,A1,A2)):f(f_){} T operator()(A0 a0, A1 a1, A2 a2) const { return f(a0,a1,a2); } private: T (&f)(A0,A1,A2); }; template class funct_wrap4 { public: funct_wrap4(T (&f_)(A0,A1,A2,A3)):f(f_){} T operator()(A0 a0, A1 a1, A2 a2, A3 a3) const { return f(a0,a1,a2,a3); } private: T (&f)(A0,A1,A2,A3); }; template class funct_wrap5 { public: funct_wrap5(T (&f_)(A0,A1,A2,A3,A4)):f(f_){} T operator()(A0 a0, A1 a1, A2 a2, A3 a3, A4 a4) const 
{ return f(a0,a1,a2,a3,a4); } private: T (&f)(A0,A1,A2,A3,A4); }; /*!A wrap_function This is a template that allows you to turn a global function into a function object. The reason for this template's existence is so you can do stuff like this: template void call_funct(const T& funct) { cout << funct(); } std::string test() { return "asdfasf"; } int main() { call_funct(wrap_function(test)); } The above code doesn't work right on some compilers if you don't use wrap_function. !*/ template funct_wrap0 wrap_function(T (&f)()) { return funct_wrap0(f); } template funct_wrap1 wrap_function(T (&f)(A0)) { return funct_wrap1(f); } template funct_wrap2 wrap_function(T (&f)(A0, A1)) { return funct_wrap2(f); } template funct_wrap3 wrap_function(T (&f)(A0, A1, A2)) { return funct_wrap3(f); } template funct_wrap4 wrap_function(T (&f)(A0, A1, A2, A3)) { return funct_wrap4(f); } template funct_wrap5 wrap_function(T (&f)(A0, A1, A2, A3, A4)) { return funct_wrap5(f); } // ---------------------------------------------------------------------------------------- template class stack_based_memory_block : noncopyable { /*! WHAT THIS OBJECT REPRESENTS This object is a simple container for a block of memory of bSIZE bytes. This memory block is located on the stack and properly aligned to hold any kind of object. The storage is a data member of this object, so it is on the stack only when the object itself is. NOTE(review): alignment comes from the members of the union below (pointers, long double, uint64, double, ...), so over-aligned types may not be covered - confirm before relying on it. !*/ public: static const unsigned long size = bSIZE; stack_based_memory_block(): data(mem.data) {} void* get () { return data; } /*! ensures - returns a pointer to the block of memory contained in this object !*/ const void* get () const { return data; } /*! ensures - returns a pointer to the block of memory contained in this object !*/ private: // You obviously can't have a block of memory that has zero bytes in it. COMPILE_TIME_ASSERT(bSIZE > 0); union mem_block { // All of this garbage is to make sure this union is properly aligned // (a union is always aligned such that everything in it would be properly // aligned. 
So the assumption here is that one of these objects has // a large enough alignment requirement to satisfy any object this // block of memory might be cast into). void* void_ptr; int integer; struct { void (stack_based_memory_block::*callback)(); stack_based_memory_block* o; } stuff; long double more_stuff; uint64 var1; uint32 var2; double var3; char data[size]; } mem; // The reason for having this variable is that doing it this way avoids // warnings from gcc about violations of strict-aliasing rules. void* const data; }; // ---------------------------------------------------------------------------------------- template < typename T, typename F > auto max_scoring_element( const T& container, F score_func ) -> decltype(std::make_pair(*container.begin(), 0.0)) /*! requires - container has .begin() and .end(), allowing it to be enumerated. - score_func() is a function that takes an element of the container and returns a double. ensures - This function finds the element of container that has the largest score, according to score_func(), and returns a std::pair containing that maximal element along with the score. - If the container is empty then make_pair(a default initialized object, -infinity) is returned. !*/ /* Ties keep the earliest element because the comparison is strict; if no score exceeds -infinity (e.g. all are -infinity or NaN) the first element is returned paired with -infinity. */ { double best_score = -std::numeric_limits::infinity(); auto best_i = container.begin(); for (auto i = container.begin(); i != container.end(); ++i) { auto score = score_func(*i); if (score > best_score) { best_score = score; best_i = i; } } using item_type = typename std::remove_reference::type; if (best_i == container.end()) return std::make_pair(item_type(), best_score); else return std::make_pair(*best_i, best_score); } // ---------------------------------------------------------------------------------------- template < typename T, typename F > auto min_scoring_element( const T& container, F score_func ) -> decltype(std::make_pair(*container.begin(), 0.0)) /*! requires - container has .begin() and .end(), allowing it to be enumerated. 
- score_func() is a function that takes an element of the container and returns a double. ensures - This function finds the element of container that has the smallest score, according to score_func(), and returns a std::pair containing that minimal element along with the score. - If the container is empty then make_pair(a default initialized object, infinity) is returned. !*/ /* Ties keep the earliest element because the comparison is strict; if no score is below +infinity (e.g. all are +infinity or NaN) the first element is returned paired with +infinity. */ { double best_score = std::numeric_limits::infinity(); auto best_i = container.begin(); for (auto i = container.begin(); i != container.end(); ++i) { auto score = score_func(*i); if (score < best_score) { best_score = score; best_i = i; } } using item_type = typename std::remove_reference::type; if (best_i == container.end()) return std::make_pair(item_type(), best_score); else return std::make_pair(*best_i, best_score); } // ---------------------------------------------------------------------------------------- /* Implementation detail of for_each_in_tuple(): calls f on each tuple element in index order, via a C++17 fold expression when available and an initializer_list expansion otherwise. */ namespace detail { template constexpr void for_each_impl(Tuple&& t, F&& f, std::index_sequence) { #ifdef __cpp_fold_expressions (std::forward(f)(std::get(std::forward(t))),...); #else (void)std::initializer_list{(std::forward(f)(std::get(std::forward(t))),0)...}; #endif } } /* Calls f(element) on every element of the tuple t, first to last. */ template constexpr void for_each_in_tuple(Tuple&& t, F&& f) { detail::for_each_impl(std::forward(t), std::forward(f), std::make_index_sequence>::value>{}); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_ALGs_ ================================================ FILE: dlib/all/source.cpp ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_ALL_SOURCe_
#define DLIB_ALL_SOURCe_

// Unity-build entry point: this file #includes the .cpp implementation files listed
// below, so compiling this single file compiles all of them in one translation unit.
//
// DLIB_ALGs_ is the include guard of dlib/algs.h.  If it (or DLIB_PLATFORm_) is already
// defined, another dlib header was included before this file.  In that case a plain text
// file is #included instead of C++ code, presumably so compilation fails with output that
// points at the build tutorial.  NOTE(review): confirm intent.
#if defined(DLIB_ALGs_) || defined(DLIB_PLATFORm_)
#include "../dlib_basic_cpp_build_tutorial.txt"
#endif

// ISO C++ code
#include "../base64/base64_kernel_1.cpp"
#include "../bigint/bigint_kernel_1.cpp"
#include "../bigint/bigint_kernel_2.cpp"
#include "../bit_stream/bit_stream_kernel_1.cpp"
#include "../entropy_decoder/entropy_decoder_kernel_1.cpp"
#include "../entropy_decoder/entropy_decoder_kernel_2.cpp"
#include "../entropy_encoder/entropy_encoder_kernel_1.cpp"
#include "../entropy_encoder/entropy_encoder_kernel_2.cpp"
#include "../md5/md5_kernel_1.cpp"
#include "../tokenizer/tokenizer_kernel_1.cpp"
#include "../unicode/unicode.cpp"
#include "../test_for_odr_violations.cpp"

// Everything from here to the matching "#endif // DLIB_ISO_CPP_ONLY" is skipped when
// DLIB_ISO_CPP_ONLY is defined.
#ifndef DLIB_ISO_CPP_ONLY
// Code that depends on OS specific APIs

// include this first so that it can disable the older version
// of the winsock API when compiled in windows.
#include "../sockets/sockets_kernel_1.cpp"
#include "../bsp/bsp.cpp"
#include "../dir_nav/dir_nav_kernel_1.cpp"
#include "../dir_nav/dir_nav_kernel_2.cpp"
#include "../dir_nav/dir_nav_extensions.cpp"
#include "../fft/fft.cpp"
#include "../linker/linker_kernel_1.cpp"
#include "../logger/extra_logger_headers.cpp"
#include "../logger/logger_kernel_1.cpp"
#include "../logger/logger_config_file.cpp"
#include "../misc_api/misc_api_kernel_1.cpp"
#include "../misc_api/misc_api_kernel_2.cpp"
#include "../sockets/sockets_extensions.cpp"
#include "../sockets/sockets_kernel_2.cpp"
#include "../sockstreambuf/sockstreambuf.cpp"
#include "../sockstreambuf/sockstreambuf_unbuffered.cpp"
#include "../server/server_kernel.cpp"
#include "../server/server_iostream.cpp"
#include "../server/server_http.cpp"
#include "../threads/multithreaded_object_extension.cpp"
#include "../threads/threaded_object_extension.cpp"
#include "../threads/threads_kernel_1.cpp"
#include "../threads/threads_kernel_2.cpp"
#include "../threads/threads_kernel_shared.cpp"
#include "../threads/thread_pool_extension.cpp"
#include "../threads/async.cpp"
#include "../timer/timer.cpp"
#include "../stack_trace.cpp"

// Image codec sources are only compiled when the matching *_SUPPORT macro is defined.
#ifdef DLIB_PNG_SUPPORT
#include "../image_loader/png_loader.cpp"
#include "../image_saver/save_png.cpp"
#endif

#ifdef DLIB_JPEG_SUPPORT
#include "../image_loader/jpeg_loader.cpp"
#include "../image_saver/save_jpeg.cpp"
#endif

// fonts.cpp is compiled even when DLIB_NO_GUI_SUPPORT is defined; only the widget and
// gui_core sources inside the guard below are excluded.
#include "../gui_widgets/fonts.cpp"
#ifndef DLIB_NO_GUI_SUPPORT
#include "../gui_widgets/widgets.cpp"
#include "../gui_widgets/drawable.cpp"
#include "../gui_widgets/canvas_drawing.cpp"
#include "../gui_widgets/style.cpp"
#include "../gui_widgets/base_widgets.cpp"
#include "../gui_core/gui_core_kernel_1.cpp"
#include "../gui_core/gui_core_kernel_2.cpp"
#endif // DLIB_NO_GUI_SUPPORT

#include "../cuda/cpu_dlib.cpp"
#include "../cuda/tensor_tools.cpp"
#include "../data_io/image_dataset_metadata.cpp"
#include "../data_io/mnist.cpp"
#include "../data_io/cifar.cpp"
#include "../svm/auto.cpp"
#include "../global_optimization/global_function_search.cpp"
#include "../filtering/kalman_filter.cpp"

#endif // DLIB_ISO_CPP_ONLY

// Defined once the full list above has been processed.  NOTE(review): no consumer of
// this macro is visible here; verify where (or whether) it is checked.
#define DLIB_ALL_SOURCE_END

#endif // DLIB_ALL_SOURCe_

================================================
FILE: dlib/any/any.h
================================================
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_AnY_H_ #define DLIB_AnY_H_ #include "any_abstract.h" #include #include "storage.h" namespace dlib { // ---------------------------------------------------------------------------------------- class any { public: any() = default; any(const any& other) = default; any& operator=(const any& other) = default; any(any&& other) = default; any& operator=(any&& other) = default; template< typename T, std::enable_if_t, any>::value, bool> = true > any(T&& item) : storage{std::forward(item)} { } template< typename T, typename T_ = std::decay_t, std::enable_if_t::value, bool> = true > any& operator=(T&& item) { if (contains()) storage.unsafe_get() = std::forward(item); else *this = any{std::forward(item)}; return *this; } bool is_empty() const { return storage.is_empty(); } void clear() { storage.clear(); } void swap (any& item) { std::swap(*this, item); } template bool contains() const { return storage.contains();} template T& cast_to() { return storage.cast_to(); } template const T& cast_to() const { return storage.cast_to(); } template T& get() { return storage.get(); } private: te::storage_heap storage; }; // ---------------------------------------------------------------------------------------- template T& any_cast(any& a) { return a.cast_to(); } template const T& any_cast(const any& a) { return a.cast_to(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_H_ ================================================ FILE: dlib/any/any_abstract.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_AnY_ABSTRACT_H_ #ifdef DLIB_AnY_ABSTRACT_H_ #include namespace dlib { // ---------------------------------------------------------------------------------------- class bad_any_cast : public std::bad_cast { /*! 
WHAT THIS OBJECT REPRESENTS This object is the exception class used by the any object. It is used to indicate when someone attempts to cast an any object into a type which isn't contained in the any object. !*/ public: virtual const char* what() const noexcept { return "bad_any_cast"; } }; // ---------------------------------------------------------------------------------------- class any { /*! INITIAL VALUE - is_empty() == true - for all T: contains() == false WHAT THIS OBJECT REPRESENTS This object is basically a type-safe version of a void*. In particular, it is a container which can contain only one object but the object may be of any type. It is somewhat like the type_safe_union except you don't have to declare the set of possible content types beforehand. So in some sense this is like a less type-strict version of the type_safe_union. !*/ public: any( ); /*! ensures - this object is properly initialized !*/ any ( const any& item ); /*! ensures - copies the state of item into *this. - Note that *this and item will contain independent copies of the contents of item. That is, this function performs a deep copy and therefore does not result in *this containing any kind of reference to item. !*/ any_function ( any_function&& item ); /*! ensures - #item.is_empty() == true - moves item into *this. !*/ template < typename T > any ( const T& item ); /*! ensures - #contains() == true - #cast_to() == item (i.e. a copy of item will be stored in *this) !*/ void clear ( ); /*! ensures - #*this will have its default value. I.e. #is_empty() == true !*/ template bool contains ( ) const; /*! ensures - if (this object currently contains an object of type T) then - returns true - else - returns false !*/ bool is_empty( ) const; /*! ensures - if (this object contains any kind of object) then - returns false - else - returns true !*/ template T& cast_to( ); /*! 
ensures - if (contains() == true) then - returns a non-const reference to the object contained within *this - else - throws bad_any_cast !*/ template const T& cast_to( ) const; /*! ensures - if (contains() == true) then - returns a const reference to the object contained within *this - else - throws bad_any_cast !*/ template T& get( ); /*! ensures - #is_empty() == false - #contains() == true - if (contains() == true) - returns a non-const reference to the object contained in *this. - else - Constructs an object of type T inside *this - Any previous object stored in this any object is destructed and its state is lost. - returns a non-const reference to the newly created T object. !*/ any& operator= ( const any& item ); /*! ensures - copies the state of item into *this. - Note that *this and item will contain independent copies of the contents of item. That is, this function performs a deep copy and therefore does not result in *this containing any kind of reference to item. !*/ void swap ( any& item ); /*! ensures - swaps *this and item - does not invalidate pointers or references to the object contained inside *this or item. Moreover, a pointer or reference to the object in *this will now refer to the contents of #item and vice versa. !*/ }; // ---------------------------------------------------------------------------------------- template < typename T > T& any_cast( any& a ) { return a.cast_to(); } /*! ensures - returns a.cast_to() !*/ // ---------------------------------------------------------------------------------------- template < typename T > const T& any_cast( const any& a ) { return a.cast_to(); } /*! ensures - returns a.cast_to() !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_ABSTRACT_H_ ================================================ FILE: dlib/any/any_decision_function.h ================================================ // Copyright (C) 2010 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_AnY_DECISION_FUNCTION_Hh_ #define DLIB_AnY_DECISION_FUNCTION_Hh_ #include "any_decision_function_abstract.h" #include "any_function.h" #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- template using any_decision_function = any_function; // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_DECISION_FUNCTION_Hh_ ================================================ FILE: dlib/any/any_decision_function_abstract.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ #ifdef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ #include "any_function_abstract.h" #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- template using any_decision_function = any_function; // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ ================================================ FILE: dlib/any/any_function.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_AnY_FUNCTION_Hh_ #define DLIB_AnY_FUNCTION_Hh_ #include "../assert.h" #include "../functional.h" #include "any.h" #include "any_function_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < class Storage, class F > class any_function_basic; template < class Storage, class R, class... 
Args > class any_function_basic { private: template using is_valid = std::enable_if_t, any_function_basic>::value && dlib::is_invocable_r::value, bool>; template static auto make_invoker() { return [](void* self, Args... args) -> R { return dlib::invoke(*reinterpret_cast>(self), std::forward(args)...); }; } Storage str; R (*func)(void*, Args...) = nullptr; public: using result_type = R; constexpr any_function_basic(std::nullptr_t) noexcept {} constexpr any_function_basic() = default; constexpr any_function_basic(const any_function_basic& other) = default; constexpr any_function_basic& operator=(const any_function_basic& other) = default; constexpr any_function_basic(any_function_basic&& other) : str{std::move(other.str)}, func{std::exchange(other.func, nullptr)} { } constexpr any_function_basic& operator=(any_function_basic&& other) { if (this != &other) { str = std::move(other.str); func = std::exchange(other.func, nullptr); } return *this; } template = true> any_function_basic( F&& f ) : str{std::forward(f)}, func{make_invoker()} { } template = true> any_function_basic( F* f ) : str{f}, func{make_invoker()} { } R operator()(Args... 
args) const { return func(const_cast(str.get_ptr()), std::forward(args)...); } void clear() { str.clear(); } void swap (any_function_basic& item) { std::swap(*this, item); } bool is_empty() const noexcept { return str.is_empty() || func == nullptr; } bool is_set() const noexcept { return !is_empty(); } explicit operator bool() const noexcept { return is_set(); } template bool contains() const { return str.template contains();} template T& cast_to() { return str.template cast_to(); } template const T& cast_to() const { return str.template cast_to(); } template T& get() { return str.template get(); } }; // ---------------------------------------------------------------------------------------- template T& any_cast(any_function_basic& a) { return a.template cast_to(); } template const T& any_cast(const any_function_basic& a) { return a.template cast_to(); } // ---------------------------------------------------------------------------------------- template using any_function = any_function_basic, F>; template using any_function_view = any_function_basic; // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_FUNCTION_Hh_ ================================================ FILE: dlib/any/any_function_abstract.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_AnY_FUNCTION_ABSTRACT_H_ #ifdef DLIB_AnY_FUNCTION_ABSTRACT_H_ #include "any_abstract.h" #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < class Storage, typename function_type > class any_function_basic { /*! REQUIREMENTS ON Storage This must be one of the storage types from dlib/any/storage.hh E.g. storage_heap, storage_stack, etc. It determines the method by which any_function_basic holds onto the function it uses. 
REQUIREMENTS ON function_type This type should be a function signature. Some examples are: void (int,int) // a function returning nothing and taking two ints void () // a function returning nothing and taking no arguments char (double&) // a function returning a char and taking a reference to a double The number of arguments in the function must be no greater than 10. INITIAL VALUE - is_empty() == true - for all T: contains() == false WHAT THIS OBJECT REPRESENTS This object is a version of dlib::any that is restricted to containing elements which are some kind of function object with an operator() which matches the function signature defined by function_type. Here is an example: #include #include #include "dlib/any.h" void print_message(string str) { cout << str << endl; } int main() { dlib::any_function f; f = print_message; f("hello world"); // calls print_message("hello world") } Note that any_function_basic objects can be used to store general function objects (i.e. defined by a class with an overloaded operator()) in addition to regular global functions. !*/ public: // This is the type of object returned by function_type functions. typedef result_type_for_function_type result_type; any_function_basic( ); /*! ensures - this object is properly initialized !*/ any_function_basic ( const any_function_basic& item ); /*! ensures - copies the state of item into *this. - Note that *this and item will contain independent copies of the contents of item. That is, this function performs a deep copy and therefore does not result in *this containing any kind of reference to item. !*/ any_function_basic ( any_function_basic&& item ); /*! ensures - moves item into *this. - The exact move semantics are determined by which Storage type is used. E.g. storage_heap will result in #item.is_empty()==true but storage_view would result in #item.is_empty() == false !*/ template < typename Funct > any_function_basic ( Funct&& funct ); /*! 
ensures - #contains() == true - #cast_to() == item (i.e. calling operator() will invoke funct()) !*/ void clear ( ); /*! ensures - #*this will have its default value. I.e. #is_empty() == true !*/ template bool contains ( ) const; /*! ensures - if (this object currently contains an object of type T) then - returns true - else - returns false !*/ bool is_empty( ) const; /*! ensures - if (this object contains any kind of object) then - returns false - else - returns true !*/ bool is_set ( ) const; /*! ensures - returns !is_empty() !*/ explicit operator bool( ) const; /*! ensures - returns is_set() !*/ result_type operator(Args... args) ( ) const; /*! requires - is_empty() == false - the signature defined by function_type takes no arguments ensures - Let F denote the function object contained within *this. Then this function performs: return F(std::forward(args)...) !*/ template T& cast_to( ); /*! ensures - if (contains() == true) then - returns a non-const reference to the object contained within *this - else - throws bad_any_cast !*/ template const T& cast_to( ) const; /*! ensures - if (contains() == true) then - returns a const reference to the object contained within *this - else - throws bad_any_cast !*/ template T& get( ); /*! ensures - #is_empty() == false - #contains() == true - if (contains() == true) - returns a non-const reference to the object contained in *this. - else - Constructs an object of type T inside *this - Any previous object stored in this any_function_basic object is destructed and its state is lost. - returns a non-const reference to the newly created T object. !*/ any_function_basic& operator= ( const any_function_basic& item ); /*! ensures - copies the state of item into *this. - Note that the type of copy is determined by the Storage template argument. E.g. storage_sbo will result in a deep copy, while storage_view would result in *this and item referring to the same underlying function. !*/ void swap ( any_function_basic& item ); /*! 
ensures - swaps *this and item !*/ }; // ---------------------------------------------------------------------------------------- template < typename T, typename function_type > T& any_cast( any_function_basic& a ) { return a.cast_to(); } /*! ensures - returns a.cast_to() !*/ // ---------------------------------------------------------------------------------------- template < typename T, typename function_type > const T& any_cast( const any_function_basic& a ) { return a.cast_to(); } /*! ensures - returns a.cast_to() !*/ // ---------------------------------------------------------------------------------------- /*!A any_function A version of any_function_basic (defined above) that owns the function it contains. Uses the small buffer optimization to make working with small lambdas faster. !*/ template using any_function = any_function_basic, F>; /*!A any_function_view A version of any_function_basic (defined above) that *DOES NOT* own the function it contains. It merely holds a pointer to the function given to its constructor. !*/ template using any_function_view = any_function_basic; // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_FUNCTION_ABSTRACT_H_ ================================================ FILE: dlib/any/any_trainer.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_AnY_TRAINER_H_ #define DLIB_AnY_TRAINER_H_ #include "any.h" #include "any_decision_function.h" #include "any_trainer_abstract.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- template < typename sample_type_, typename scalar_type_ = double > class any_trainer { public: using sample_type = sample_type_; using scalar_type = scalar_type_; using mem_manager_type = default_memory_manager; using trained_function_type = any_decision_function; any_trainer() = default; any_trainer(const any_trainer& other) = default; any_trainer& operator=(const any_trainer& other) = default; any_trainer(any_trainer&& other) = default; any_trainer& operator=(any_trainer&& other) = default; template < class T, class T_ = std::decay_t, std::enable_if_t::value, bool> = true > any_trainer ( T&& item ) : storage{std::forward(item)}, train_func{[]( const void* ptr, const std::vector& samples, const std::vector& labels ) -> trained_function_type { const T_& f = *reinterpret_cast(ptr); return f.train(samples, labels); }} { } template < class T, class T_ = std::decay_t, std::enable_if_t::value, bool> = true > any_trainer& operator= ( T&& item ) { if (contains()) storage.unsafe_get() = std::forward(item); else *this = std::move(any_trainer{std::forward(item)}); return *this; } trained_function_type train ( const std::vector& samples, const std::vector& labels ) const { // make sure requires clause is not broken DLIB_ASSERT(is_empty() == false, "\t trained_function_type any_trainer::train()" << "\n\t You can't call train() on an empty any_trainer" << "\n\t this: " << this ); return train_func(storage.get_ptr(), samples, labels); } bool is_empty() const { return storage.is_empty(); } void clear() { storage.clear(); } void swap (any_trainer& item) { std::swap(*this, item); } template bool contains() const { return storage.contains();} template T& cast_to() { return storage.cast_to(); } template const T& cast_to() const 
{ return storage.cast_to(); } template T& get() { return storage.get(); } private: te::storage_heap storage; trained_function_type (*train_func) ( const void* self, const std::vector& samples, const std::vector& labels ) = nullptr; }; // ---------------------------------------------------------------------------------------- template T& any_cast(any_trainer& a) { return a.template cast_to(); } template const T& any_cast(const any_trainer& a) { return a.template cast_to(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_TRAINER_H_ ================================================ FILE: dlib/any/any_trainer_abstract.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_AnY_TRAINER_ABSTRACT_H_ #ifdef DLIB_AnY_TRAINER_ABSTRACT_H_ #include "any_abstract.h" #include "../algs.h" #include "any_decision_function_abstract.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- template < typename sample_type_, typename scalar_type_ = double > class any_trainer { /*! INITIAL VALUE - is_empty() == true - for all T: contains() == false WHAT THIS OBJECT REPRESENTS This object is a version of dlib::any that is restricted to containing elements which are some kind of object with a .train() method compatible with the following signature: decision_function train( const std::vector& samples, const std::vector& labels ) const Where decision_function is a type capable of being stored in an any_decision_function object. any_trainer is intended to be used to contain objects such as the svm_nu_trainer and other similar types which represent supervised machine learning algorithms. It allows you to write code which contains and processes these trainer objects without needing to know the specific types of trainer objects used. 
!*/ public: typedef sample_type_ sample_type; typedef scalar_type_ scalar_type; typedef default_memory_manager mem_manager_type; typedef any_decision_function trained_function_type; any_trainer( ); /*! ensures - this object is properly initialized !*/ any_trainer ( const any_trainer& item ); /*! ensures - copies the state of item into *this. - Note that *this and item will contain independent copies of the contents of item. That is, this function performs a deep copy and therefore does not result in *this containing any kind of reference to item. !*/ any_trainer ( any_trainer&& item ); /*! ensures - #item.is_empty() == true - moves item into *this. !*/ template < typename T > any_trainer ( const T& item ); /*! ensures - #contains() == true - #cast_to() == item (i.e. a copy of item will be stored in *this) !*/ void clear ( ); /*! ensures - #*this will have its default value. I.e. #is_empty() == true !*/ template bool contains ( ) const; /*! ensures - if (this object currently contains an object of type T) then - returns true - else - returns false !*/ bool is_empty( ) const; /*! ensures - if (this object contains any kind of object) then - returns false - else - returns true !*/ trained_function_type train ( const std::vector& samples, const std::vector& labels ) const /*! requires - is_empty() == false ensures - Let TRAINER denote the object contained within *this. Then this function performs: return TRAINER.train(samples, labels) !*/ template T& cast_to( ); /*! ensures - if (contains() == true) then - returns a non-const reference to the object contained within *this - else - throws bad_any_cast !*/ template const T& cast_to( ) const; /*! ensures - if (contains() == true) then - returns a const reference to the object contained within *this - else - throws bad_any_cast !*/ template T& get( ); /*! ensures - #is_empty() == false - #contains() == true - if (contains() == true) - returns a non-const reference to the object contained in *this. 
- else - Constructs an object of type T inside *this - Any previous object stored in this any_trainer object is destructed and its state is lost. - returns a non-const reference to the newly created T object. !*/ any_trainer& operator= ( const any_trainer& item ); /*! ensures - copies the state of item into *this. - Note that *this and item will contain independent copies of the contents of item. That is, this function performs a deep copy and therefore does not result in *this containing any kind of reference to item. !*/ void swap ( any_trainer& item ); /*! ensures - swaps *this and item !*/ }; // ---------------------------------------------------------------------------------------- template < typename sample_type, typename scalar_type > inline void swap ( any_trainer& a, any_trainer& b ) { a.swap(b); } /*! provides a global swap function !*/ // ---------------------------------------------------------------------------------------- template < typename T, typename sample_type, typename scalar_type > T& any_cast( any_trainer& a ) { return a.cast_to(); } /*! ensures - returns a.cast_to() !*/ // ---------------------------------------------------------------------------------------- template < typename T, typename sample_type, typename scalar_type > const T& any_cast( const any_trainer& a ) { return a.cast_to(); } /*! ensures - returns a.cast_to() !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_TRAINER_ABSTRACT_H_ ================================================ FILE: dlib/any/storage.h ================================================ // Copyright (C) 2022 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_TYPE_ERASURE_H_
#define DLIB_TYPE_ERASURE_H_

// NOTE(review): in this copy of the file, text that was enclosed in angle brackets
// (the #include targets below, template parameter/argument lists, static_cast /
// reinterpret_cast target types, etc.) appears to have been stripped during
// extraction.  Restore it from upstream dlib before trying to compile this code.
#include
#include
#include
#include
#include
#include "../assert.h"

namespace dlib
{

// -----------------------------------------------------------------------------------------------------

    class bad_any_cast : public std::bad_cast 
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object is the exception class used by the storage objects.
                It is used to indicate when someone attempts to cast a storage
                object into a type which isn't contained in the object.
        !*/

    public:
        virtual const char * what() const noexcept
        {
            return "bad_any_cast";
        }
    };

// -----------------------------------------------------------------------------------------------------

    namespace te
    {
        /*!
            This is used as a SFINAE tool to prevent a function taking a universal
            reference from binding to some undesired type.  For example:

                template <
                    typename T,
                    T_is_not_this_type = true
                >
                void foo(T&&);

            prevents foo() from binding to an object of type SomeExcludedType.
            (The template arguments of T_is_not_this_type are missing in this copy.)
        !*/
        template
        using T_is_not_this_type = std::enable_if_t, Storage>::value, bool>;

// -----------------------------------------------------------------------------------------------------

        template
        class storage_base
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This class defines functionality common to all type erasure storage
                    objects (defined below in this file).  These objects are essentially
                    type-safe versions of a void*.  In particular, they are containers
                    which can contain only one object, but the object may be of any type.

                    Each storage object implements a different way of storing the
                    underlying object, e.g. on the heap, on the stack or some other more
                    specialized method.

                IMPLEMENTATION NOTE
                    Every member below reaches the derived storage object through a
                    static_cast of this to Storage.  Each storage class in this file
                    supplies the get_ptr() and type_id() members that are used here.
            !*/
        public:

            bool is_empty() const
            /*!
                ensures
                    - if (this object contains any kind of object) then
                        - returns false
                    - else
                        - returns true
            !*/
            {
                // Emptiness is defined purely by the derived object's get_ptr().
                const Storage& me = *static_cast(this);
                return me.get_ptr() == nullptr;
            }

            template
            bool contains() const
            /*!
                ensures
                    - if (this object currently contains an object of type T) then
                        - returns true
                    - else
                        - returns false
            !*/
            {
                const Storage& me = *static_cast(this);
                // is_empty() is tested first because every type_id() in this file
                // asserts that the object is non-empty.
                return !is_empty() && me.type_id() == std::type_index{typeid(T)};
            }

            template
            T& unsafe_get()
            /*!
                requires
                    - contains() == true
                ensures
                    - returns a reference to the object contained within *this.
            !*/
            {
                // Only checked by DLIB_ASSERT; no runtime type check otherwise.
                DLIB_ASSERT(contains());
                Storage& me = *static_cast(this);
                return *reinterpret_cast(me.get_ptr());
            }

            template
            const T& unsafe_get() const
            /*!
                requires
                    - contains() == true
                ensures
                    - returns a const reference to the object contained within *this.
            !*/
            {
                DLIB_ASSERT(contains());
                const Storage& me = *static_cast(this);
                return *reinterpret_cast(me.get_ptr());
            }

            template
            T& get(
            )
            /*!
                ensures
                    - #is_empty() == false
                    - #contains() == true
                    - if (contains() == true)
                        - returns a non-const reference to the object contained in *this.
                    - else
                        - Constructs an object of type T inside *this
                        - Any previous object stored in this any object is destructed and
                          its state is lost.
                        - returns a non-const reference to the newly created T object.
            !*/
            {
                Storage& me = *static_cast(this);
                // Assigning T{} goes through the storage type's converting constructor
                // and then its assignment operator, replacing any previous contents.
                // NOTE(review): for storage_view, which only records &t, this binds the
                // view to the temporary T{}.  That temporary is destroyed at the end of
                // the statement, so the returned reference would dangle.  Confirm that
                // get() is never used with storage_view.
                if (!contains())
                    me = T{};
                return unsafe_get();
            }

            template
            T& cast_to(
            )
            /*!
                ensures
                    - if (contains() == true) then
                        - returns a non-const reference to the object contained within *this
                    - else
                        - throws bad_any_cast
            !*/
            {
                if (!contains())
                    throw bad_any_cast{};
                return unsafe_get();
            }

            template
            const T& cast_to(
            ) const
            /*!
                ensures
                    - if (contains() == true) then
                        - returns a const reference to the object contained within *this
                    - else
                        - throws bad_any_cast
            !*/
            {
                if (!contains())
                    throw bad_any_cast{};
                return unsafe_get();
            }
        };

// -----------------------------------------------------------------------------------------------------

        class storage_heap : public storage_base
        {
        public:
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This object is a storage type that uses type erasure to erase any type.

                    This particular storage type uses heap allocation only.

                IMPLEMENTATION NOTE
                    The erased type is remembered through three function pointers
                    (del, copy, type_id_), each set from a capture-less lambda in the
                    converting constructor.
            !*/

            storage_heap() = default;
            /*!
                ensures
                    - #is_empty() == true
                    - for all T: #contains() == false
            !*/

            // NOTE(review): the noexcept condition below only reflects whether T_ can be
            // constructed without throwing.  The body also allocates with new, which can
            // throw std::bad_alloc.  If that happens while noexcept evaluates to true,
            // std::terminate is called.  Consider dropping or tightening this noexcept.
            template <
                class T,
                class T_ = std::decay_t,
                T_is_not_this_type = true
            >
            storage_heap(T &&t) noexcept(std::is_nothrow_constructible::value)
            /*!
                ensures
                    - copies or moves the incoming object (depending on the forwarding reference)
                    - #is_empty() == false
                    - #contains>() == true
                    - #unsafe_get() will yield the provided t.
            !*/
            :   ptr{new T_{std::forward(t)}},
                del{[](void *self) {
                    delete reinterpret_cast(self);
                }},
                copy{[](const void *self) -> void * {
                    return new T_{*reinterpret_cast(self)};
                }},
                type_id_{[] {
                    return std::type_index{typeid(T_)};
                }}
            {
            }

            storage_heap(const storage_heap& other)
            /*!
                ensures
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is copied using erased type's copy constructor.
            !*/
            :   ptr{other.ptr ? other.copy(other.ptr) : nullptr},
                del{other.del},
                copy{other.copy},
                type_id_{other.type_id_}
            {
            }

            storage_heap& operator=(const storage_heap& other)
            /*!
                ensures
                    - if is_empty() == false then
                        - destructs the object contained in this class.
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is copied using erased type's copy constructor.
            !*/
            {
                // Make a full copy first, then move-assign it into *this.  The
                // move-assignment releases the old contents.
                if (this != &other)
                    *this = storage_heap{other};
                return *this;
            }

            storage_heap(storage_heap&& other) noexcept
            /*!
                ensures
                    - The state of other is moved into *this.
                    - #other.is_empty() == true
            !*/
            :   ptr{std::exchange(other.ptr, nullptr)},
                del{std::exchange(other.del, nullptr)},
                copy{std::exchange(other.copy, nullptr)},
                type_id_{std::exchange(other.type_id_, nullptr)}
            {
            }

            storage_heap& operator=(storage_heap&& other) noexcept
            /*!
                ensures
                    - The state of other is moved into *this.
                    - #other.is_empty() == true
                    - returns *this
            !*/
            {
                if (this != &other)
                {
                    clear();
                    ptr      = std::exchange(other.ptr, nullptr);
                    del      = std::exchange(other.del, nullptr);
                    copy     = std::exchange(other.copy, nullptr);
                    type_id_ = std::exchange(other.type_id_, nullptr);
                }
                return *this;
            }

            ~storage_heap()
            /*!
                ensures
                    - destructs the object contained in *this if one exists.
            !*/
            {
                if (ptr)
                    del(ptr);
            }

            void clear()
            /*!
                ensures
                    - #is_empty() == true
            !*/
            {
                // Move the contents into a temporary; its destructor frees them at the
                // end of this statement, and the move constructor leaves *this empty.
                storage_heap{std::move(*this)};
            }

            void* get_ptr()
            /*!
                ensures
                    - returns a pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr;
            }

            const void* get_ptr() const
            /*!
                ensures
                    - returns a const pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr;
            }

            std::type_index type_id() const
            /*!
                requires
                    - is_empty() == false
                ensures
                    - returns the std::type_index of the type contained within this object.
                      I.e. if this object contains the type T then this returns
                      std::type_index{typeid(T)}.
            !*/
            {
                DLIB_ASSERT(!this->is_empty());
                return type_id_();
            }

        private:
            void* ptr = nullptr;
            void (*del)(void*) = nullptr;
            void* (*copy)(const void*) = nullptr;
            std::type_index (*type_id_)() = nullptr;
        };

// -----------------------------------------------------------------------------------------------------

        template
        class storage_stack : public storage_base>
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This object is a storage type that uses type erasure to erase any type.

                    This particular storage type uses stack allocation using a template size
                    and alignment.  Therefore, only objects whose size and alignment fits
                    the template parameters can be erased and absorbed into this object.
                    Attempting to store a type not representable on the stack with those
                    settings will result in a build error.

                    This object will be capable of storing any type with an alignment
                    requirement that is a divisor of Alignment.

                IMPLEMENTATION NOTE
                    The del pointer doubles as the "something is stored" flag: get_ptr()
                    returns nullptr whenever del is nullptr, and del resets every function
                    pointer after destroying the object.
            !*/

        public:

            storage_stack() = default;
            /*!
                ensures
                    - #is_empty() == true
                    - for all T: #contains() == false
            !*/

            template <
                class T,
                class T_ = std::decay_t,
                T_is_not_this_type = true
            >
            storage_stack(T &&t) noexcept(std::is_nothrow_constructible::value)
            /*!
                ensures
                    - copies or moves the incoming object (depending on the forwarding reference)
                    - #is_empty() == false
                    - #contains>() == true
            !*/
            :   del{[](storage_stack& self) {
                    reinterpret_cast(&self.data)->~T_();
                    self.del        = nullptr;
                    self.copy       = nullptr;
                    self.move       = nullptr;
                    self.type_id_   = nullptr;
                }},
                copy{[](const storage_stack& src, storage_stack& dst) {
                    new (&dst.data) T_{*reinterpret_cast(&src.data)};
                    dst.del         = src.del;
                    dst.copy        = src.copy;
                    dst.move        = src.move;
                    dst.type_id_    = src.type_id_;
                }},
                move{[](storage_stack& src, storage_stack& dst) {
                    new (&dst.data) T_{std::move(*reinterpret_cast(&src.data))};
                    dst.del         = src.del;
                    dst.copy        = src.copy;
                    dst.move        = src.move;
                    dst.type_id_    = src.type_id_;
                }},
                type_id_{[] {
                    return std::type_index{typeid(T_)};
                }}
            {
                // Reject, at compile time, types that do not fit the in-object buffer.
                static_assert(sizeof(T_) <= Size, "insufficient size");
                static_assert(Alignment % alignof(T_) == 0, "bad alignment");
                new (&data) T_{std::forward(t)};
            }

            storage_stack(const storage_stack& other)
            /*!
                ensures
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is copied using erased type's copy constructor.
            !*/
            {
                // *this starts empty (all function pointers default to nullptr), so
                // other.copy can construct straight into data.
                if (other.copy)
                    other.copy(other, *this);
            }

            storage_stack& operator=(const storage_stack& other)
            /*!
                ensures
                    - #is_empty() == other.is_empty()
                    - if is_empty() == false then
                        - destructs the object contained in this class.
                    - if other.is_empty() == false then
                        - underlying object of other is copied using erased type's copy constructor
            !*/
            {
                if (this != &other)
                {
                    clear();
                    if (other.copy)
                        other.copy(other, *this);
                }
                return *this;
            }

            storage_stack(storage_stack&& other)
            /*!
                ensures
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is moved using erased type's moved constructor
            !*/
            {
                if (other.move)
                    other.move(other, *this);
            }

            storage_stack& operator=(storage_stack&& other)
            /*!
                ensures
                    - if is_empty() == false then
                        - destructs the object contained in this class.
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is moved using erased type's moved constructor.
                          This does not make other empty.  It will still contain a moved from
                          object of the underlying type in whatever that object's moved from
                          state is.
                        - #other.is_empty() == false
            !*/
            {
                if (this != &other)
                {
                    clear();
                    if (other.move)
                        other.move(other, *this);
                }
                return *this;
            }

            ~storage_stack()
            /*!
                ensures
                    - destructs the object contained in *this if one exists.
            !*/
            {
                clear();
            }

            void clear()
            /*!
                ensures
                    - #is_empty() == true
            !*/
            {
                // del destroys the stored object and nulls out every function pointer.
                if (del)
                    del(*this);
            }

            void* get_ptr()
            /*!
                ensures
                    - returns a pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return del ? (void*)&data : nullptr;
            }

            const void* get_ptr() const
            /*!
                ensures
                    - returns a const pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return del ? (const void*)&data : nullptr;
            }

            std::type_index type_id() const
            /*!
                requires
                    - is_empty() == false
                ensures
                    - returns the std::type_index of the type contained within this object.
                      I.e. if this object contains the type T then this returns
                      std::type_index{typeid(T)}.
            !*/
            {
                DLIB_ASSERT(!this->is_empty());
                return type_id_();
            }

        private:
            alignas(Alignment) unsigned char data[Size];
            void (*del)(storage_stack&) = nullptr;
            void (*copy)(const storage_stack&, storage_stack&) = nullptr;
            void (*move)(storage_stack&, storage_stack&) = nullptr;
            std::type_index (*type_id_)() = nullptr;
        };

// -----------------------------------------------------------------------------------------------------

        template
        class storage_sbo : public storage_base>
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This object is a storage type that uses type erasure to erase any type.

                    This particular storage type uses small buffer optimization (SBO), i.e.
                    optional stack allocation if the erased type has sizeof <= Size and
                    alignment requirements no greater than the given Alignment template
                    value.  If not it allocates the object on the heap.

                IMPLEMENTATION NOTE
                    ptr is non-null exactly when an object is stored.  It points either into
                    data (stack case) or at a heap allocation (heap case).  The del/copy/move
                    function pointers are set by whichever converting constructor was
                    selected, so the two cases never have to be distinguished at runtime.
            !*/

        public:
            // type_fits::value tells us if our SBO can hold T.
            template
            struct type_fits : std::integral_constant{};

            storage_sbo() = default;
            /*!
                ensures
                    - #is_empty() == true
                    - for all T: #contains() == false
            !*/

            template <
                class T,
                class T_ = std::decay_t,
                T_is_not_this_type = true,
                std::enable_if_t::value, bool> = true
            >
            storage_sbo(T &&t) noexcept(std::is_nothrow_constructible::value)
            /*!
                ensures
                    - copies or moves the incoming object (depending on the forwarding reference)
                    - #is_empty() == false
                    - #contains>() == true
                    - stack allocation is used
            !*/
            :   ptr{new (&data) T_{std::forward(t)}},
                del{[](storage_sbo& self) {
                    reinterpret_cast(&self.data)->~T_();
                    self.ptr        = nullptr;
                    self.del        = nullptr;
                    self.copy       = nullptr;
                    self.move       = nullptr;
                    self.type_id_   = nullptr;
                }},
                copy{[](const storage_sbo& src, storage_sbo& dst) {
                    dst.ptr         = new (&dst.data) T_{*reinterpret_cast(src.ptr)};
                    dst.del         = src.del;
                    dst.copy        = src.copy;
                    dst.move        = src.move;
                    dst.type_id_    = src.type_id_;
                }},
                move{[](storage_sbo& src, storage_sbo& dst) {
                    // The stack case move-constructs in place; src keeps a moved-from object.
                    dst.ptr         = new (&dst.data) T_{std::move(*reinterpret_cast(src.ptr))};
                    dst.del         = src.del;
                    dst.copy        = src.copy;
                    dst.move        = src.move;
                    dst.type_id_    = src.type_id_;
                }},
                type_id_{[] {
                    return std::type_index{typeid(T_)};
                }}
            {
            }

            // NOTE(review): as with storage_heap, this noexcept ignores that new can throw
            // std::bad_alloc.  Such a throw would call std::terminate when the condition
            // evaluates to true.
            template <
                class T,
                class T_ = std::decay_t,
                T_is_not_this_type = true,
                std::enable_if_t::value, bool> = true
            >
            storage_sbo(T &&t) noexcept(std::is_nothrow_constructible::value)
            /*!
                ensures
                    - copies or moves the incoming object (depending on the forwarding reference)
                    - #is_empty() == false
                    - #contains>() == true
                    - heap allocation is used
            !*/
            :   ptr{new T_{std::forward(t)}},
                del{[](storage_sbo& self) {
                    delete reinterpret_cast(self.ptr);
                    self.ptr        = nullptr;
                    self.del        = nullptr;
                    self.copy       = nullptr;
                    self.move       = nullptr;
                    self.type_id_   = nullptr;
                }},
                copy{[](const storage_sbo& src, storage_sbo& dst) {
                    dst.ptr         = new T_{*reinterpret_cast(src.ptr)};
                    dst.del         = src.del;
                    dst.copy        = src.copy;
                    dst.move        = src.move;
                    dst.type_id_    = src.type_id_;
                }},
                move{[](storage_sbo& src, storage_sbo& dst) {
                    // The heap case only transfers ownership of the pointer; src ends up empty.
                    dst.ptr         = std::exchange(src.ptr, nullptr);
                    dst.del         = std::exchange(src.del, nullptr);
                    dst.copy        = std::exchange(src.copy, nullptr);
                    dst.move        = std::exchange(src.move, nullptr);
                    dst.type_id_    = std::exchange(src.type_id_, nullptr);
                }},
                type_id_{[] {
                    return std::type_index{typeid(T_)};
                }}
            {
            }

            storage_sbo(const storage_sbo& other)
            /*!
                ensures
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is copied using erased type's copy constructor
            !*/
            {
                if (other.copy)
                    other.copy(other, *this);
            }

            storage_sbo& operator=(const storage_sbo& other)
            /*!
                ensures
                    - if is_empty() == false then
                        - destructs the object contained in this class.
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - underlying object of other is copied using erased type's copy constructor
            !*/
            {
                if (this != &other)
                {
                    clear();
                    if (other.copy)
                        other.copy(other, *this);
                }
                return *this;
            }

            storage_sbo(storage_sbo&& other)
            /*!
                ensures
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - if underlying object of other is allocated on stack then
                            - underlying object of other is moved using erased type's moved constructor.
                              This does not make other empty.  It will still contain a moved from
                              object of the underlying type in whatever that object's moved from
                              state is.
                            - #other.is_empty() == false
                        - else
                            - storage heap pointer is moved.
                            - #other.is_empty() == true
            !*/
            {
                if (other.move)
                    other.move(other, *this);
            }

            storage_sbo& operator=(storage_sbo&& other)
            /*!
                ensures
                    - underlying object is destructed if is_empty() == false
                    - #is_empty() == other.is_empty()
                    - if other.is_empty() == false then
                        - if underlying object of other is allocated on stack then
                            - underlying object of other is moved using erased type's moved constructor.
                              This does not make other empty.  It will still contain a moved from
                              object of the underlying type in whatever that object's moved from
                              state is.
                            - #other.is_empty() == false
                        - else
                            - storage heap pointer is moved.
                            - #other.is_empty() == true
            !*/
            {
                if (this != &other)
                {
                    clear();
                    if (other.move)
                        other.move(other, *this);
                }
                return *this;
            }

            ~storage_sbo()
            /*!
                ensures
                    - destructs the object contained in *this if one exists.
            !*/
            {
                clear();
            }

            void clear()
            /*!
                ensures
                    - #is_empty() == true
            !*/
            {
                if (ptr)
                    del(*this);
            }

            void* get_ptr()
            /*!
                ensures
                    - returns a pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr;
            }

            const void* get_ptr() const
            /*!
                ensures
                    - returns a const pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr;
            }

            std::type_index type_id() const
            /*!
                requires
                    - is_empty() == false
                ensures
                    - returns the std::type_index of the type contained within this object.
                      I.e. if this object contains the type T then this returns
                      std::type_index{typeid(T)}.
            !*/
            {
                DLIB_ASSERT(!this->is_empty());
                return type_id_();
            }

        private:
            alignas(Alignment) unsigned char data[Size];
            void* ptr = nullptr;
            void (*del)(storage_sbo&) = nullptr;
            void (*copy)(const storage_sbo&, storage_sbo&) = nullptr;
            void (*move)(storage_sbo&, storage_sbo&) = nullptr;
            std::type_index (*type_id_)() = nullptr;
        };

// -----------------------------------------------------------------------------------------------------

        class storage_shared : public storage_base
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This object is a storage type that uses type erasure to erase any type.

                    This particular storage type uses std::shared_ptr to store and erase
                    incoming objects.  Therefore, it uses heap allocation and reference
                    counting.  Moreover, it has the same copying and move semantics as
                    std::shared_ptr, i.e. copies share the underlying object rather than
                    duplicating it.
            !*/

        public:

            storage_shared() = default;
            /*!
                ensures
                    - #is_empty() == true
                    - for all T: #contains() == false
            !*/

            // NOTE(review): make_shared allocates and can throw std::bad_alloc, which this
            // conditional noexcept does not account for (a throw would call std::terminate).
            template <
                class T,
                class T_ = std::decay_t,
                T_is_not_this_type = true
            >
            storage_shared(T &&t) noexcept(std::is_nothrow_constructible::value)
            /*!
                ensures
                    - copies or moves the incoming object (depending on the forwarding reference)
                    - #is_empty() == false
                    - #contains>() == true
            !*/
            :   ptr{std::make_shared(std::forward(t))},
                type_id_{[] {
                    return std::type_index{typeid(T_)};
                }}
            {
            }

            // This object has the same copy/move semantics as a std::shared_ptr
            storage_shared(const storage_shared& other)                 = default;
            storage_shared& operator=(const storage_shared& other)      = default;
            storage_shared(storage_shared&& other) noexcept             = default;
            storage_shared& operator=(storage_shared&& other) noexcept  = default;

            void clear()
            /*!
                ensures
                    - #is_empty() == true
            !*/
            {
                // Drops this object's reference; the stored object is destroyed only when
                // the last sharing storage_shared releases it.
                ptr         = nullptr;
                type_id_    = nullptr;
            }

            void* get_ptr()
            /*!
                ensures
                    - returns a pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr.get();
            }

            const void* get_ptr() const
            /*!
                ensures
                    - returns a const pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr.get();
            }

            std::type_index type_id() const
            /*!
                requires
                    - is_empty() == false
                ensures
                    - returns the std::type_index of the type contained within this object.
                      I.e. if this object contains the type T then this returns
                      std::type_index{typeid(T)}.
            !*/
            {
                DLIB_ASSERT(!this->is_empty());
                return type_id_();
            }

        private:
            std::shared_ptr ptr = nullptr;
            std::type_index (*type_id_)() = nullptr;
        };

// -----------------------------------------------------------------------------------------------------

        class storage_view : public storage_base
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This object is a storage type that uses type erasure to erase any type.

                    This particular storage type is a view type, similar to std::string_view
                    or std::span.  So underlying objects are only ever referenced, not copied,
                    moved or destructed.  That is, instances of this object take no ownership
                    of the objects they contain, so they are only valid as long as the
                    contained object exists.  storage_view merely holds a pointer to the
                    underlying object.
            !*/

        public:

            storage_view() = default;
            /*!
                ensures
                    - #is_empty() == true
                    - for all T: #contains() == false
            !*/

            // NOTE(review): ptr is a non-const void*, so constructing from a const lvalue
            // (where &t is a pointer-to-const) looks like it would not compile.  Binding
            // an rvalue stores the address of an object that may not outlive the view.
            // Confirm both are intended.
            template <
                class T,
                class T_ = std::decay_t,
                T_is_not_this_type = true
            >
            storage_view(T &&t) noexcept
            /*!
                ensures
                    - #get_ptr() == &t
                    - #is_empty() == false
                    - #contains>() == true
            !*/
            :   ptr{&t},
                type_id_{[] {
                    return std::type_index{typeid(T_)};
                }}
            {
            }

            // This object has the same copy/move semantics as a void*.
            storage_view(const storage_view& other)                 = default;
            storage_view& operator=(const storage_view& other)      = default;
            storage_view(storage_view&& other) noexcept             = default;
            storage_view& operator=(storage_view&& other) noexcept  = default;

            void clear()
            /*!
                ensures
                    - #is_empty() == true
            !*/
            {
                // Only forgets the referenced object; nothing is destroyed.
                ptr         = nullptr;
                type_id_    = nullptr;
            }

            void* get_ptr()
            /*!
                ensures
                    - returns a pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr;
            }

            const void* get_ptr() const
            /*!
                ensures
                    - returns a const pointer to the underlying object or nullptr if is_empty()
            !*/
            {
                return ptr;
            }

            std::type_index type_id() const
            /*!
                requires
                    - is_empty() == false
                ensures
                    - returns the std::type_index of the type contained within this object.
                      I.e. if this object contains the type T then this returns
                      std::type_index{typeid(T)}.
            !*/
            {
                DLIB_ASSERT(!this->is_empty());
                return type_id_();
            }

        private:
            void* ptr = nullptr;
            std::type_index (*type_id_)() = nullptr;
        };

// -----------------------------------------------------------------------------------------------------

    }
}

#endif //DLIB_TYPE_ERASURE_H_
================================================
FILE: dlib/any.h
================================================
// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_AnY_
#define DLIB_AnY_

// Convenience header: pulls in the any, any_trainer, any_decision_function and
// any_function headers.
#include "any/any.h"
#include "any/any_trainer.h"
#include "any/any_decision_function.h"
#include "any/any_function.h"

#endif // DLIB_AnY_
================================================
FILE: dlib/array/array_kernel.h
================================================
// Copyright (C) 2003  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY_KERNEl_2_
#define DLIB_ARRAY_KERNEl_2_

#include "array_kernel_abstract.h"
#include "../interfaces/enumerable.h"
#include "../algs.h"
#include "../serialize.h"
#include "../sort.h"
#include "../is_kind.h"

namespace dlib
{

    template <
        typename T,
        typename mem_manager = default_memory_manager 
        >
    class array : public enumerable
    {

        /*!
            INITIAL VALUE
                - array_size == 0
                - max_array_size == 0
                - array_elements == 0
                - pos == 0
                - last_pos == 0
                - _at_start == true

            CONVENTION
                - array_size == size()
                - max_array_size == max_size()
                - if (max_array_size > 0)
                    - array_elements == pointer to max_array_size elements of type T
                - else
                    - array_elements == 0
                - if (array_size > 0)
                    - last_pos == array_elements + array_size - 1
                - else
                    - last_pos == 0

                - at_start() == _at_start
                - current_element_valid() == pos != 0
                - if (current_element_valid()) then
                    - *pos == element()
        !*/

    public:

        // These typedefs are here for backwards compatibility with old versions of dlib.
        typedef array kernel_1a;
        typedef array kernel_1a_c;
        typedef array kernel_2a;
        typedef array kernel_2a_c;
        typedef array sort_1a;
        typedef array sort_1a_c;
        typedef array sort_1b;
        typedef array sort_1b_c;
        typedef array sort_2a;
        typedef array sort_2a_c;
        typedef array sort_2b;
        typedef array sort_2b_c;
        typedef array expand_1a;
        typedef array expand_1a_c;
        typedef array expand_1b;
        typedef array expand_1b_c;
        typedef array expand_1c;
        typedef array expand_1c_c;
        typedef array expand_1d;
        typedef array expand_1d_c;

        typedef T type;
        typedef T value_type;
        typedef mem_manager mem_manager_type;

        array (
        ) : 
            array_size(0),
            max_array_size(0),
            array_elements(0),
            pos(0),
            last_pos(0),
            _at_start(true)
        {}

        // Copying is disabled; the array can only be moved or swapped.
        array(const array&) = delete;
        array& operator=(array&) = delete;

        // Delegates to the default constructor so *this is empty, then swaps; item is
        // therefore left empty.
        array(
            array&& item
        ) : array()
        {
            swap(item);
        }

        // Implemented as a swap, so item receives this object's previous contents.
        array& operator=(
            array&& item
        )
        {
            swap(item);
            return *this;
        }

        explicit array (
            size_t new_size
        ) :
            array_size(0),
            max_array_size(0),
            array_elements(0),
            pos(0),
            last_pos(0),
            _at_start(true)
        {
            resize(new_size);
        }

        ~array (
        ); 

        void clear (
        );

        inline const T& operator[] (
            size_t pos
        ) const;

        inline T& operator[] (
            size_t pos
        );

        void set_size (
            size_t size
        );

        inline size_t max_size(
        ) const;

        void set_max_size(
            size_t max
        );

        void swap (
            array& item
        );

        // functions from the enumerable interface
        inline size_t size (
        ) const;

        inline bool at_start (
        ) const;

        inline void reset (
        ) const;

        bool current_element_valid (
        ) const;

        inline const T& element (
        ) const;

        inline T& element (
        );

        bool move_next (
        ) const;

        void sort (
        );

        void resize (
            size_t new_size
        );

        const T& back (
        ) const;

        T& back (
        );

        void pop_back (
        );

        void pop_back (
            T& item
        );

        void push_back (
            T& item
        );

        void push_back (
            T&& item
        );

        // Raw-pointer iterators over the first size() elements of the buffer.
        typedef T* iterator;
        typedef const T* const_iterator;
        iterator begin() { return array_elements; }
        const_iterator begin() const { return array_elements; }
        iterator end() { return array_elements+array_size; }
        const_iterator end() const { return array_elements+array_size; }

    private:

        // Allocator used for array_elements (allocate_array / deallocate_array).
        typename mem_manager::template rebind::other pool;

        // data members
        size_t array_size;
        size_t max_array_size;
        T* array_elements;

        // Enumerator state; mutable so the const enumerable functions can update it.
        mutable T* pos;
        T* last_pos;
        mutable bool _at_start;

    };

    // Free swap so array works with swap(a,b) calls.
    template <
        typename T,
        typename mem_manager
        >
    inline void swap (
        array& a, 
        array& b 
    ) { a.swap(b); }

// ----------------------------------------------------------------------------------------

    // Writes max_size(), size(), then each of the size() elements.
    template <
        typename T,
        typename mem_manager
        >
    void serialize (
        const array& item, 
        std::ostream& out
    )
    {
        try
        {
            serialize(item.max_size(),out);
            serialize(item.size(),out);

            for (size_t i = 0; i < item.size(); ++i)
                serialize(item[i],out);
        }
        catch (serialization_error& e)
        { 
            throw serialization_error(e.info + "\n while serializing object of type array"); 
        }
    }

    // Reads the format written by serialize().  On a serialization_error the array is
    // cleared before the error is rethrown with added context.
    template <
        typename T,
        typename mem_manager
        >
    void deserialize (
        array& item, 
        std::istream& in
    )
    {
        try
        {
            size_t max_size, size;
            deserialize(max_size,in);
            deserialize(size,in);
            item.set_max_size(max_size);
            item.set_size(size);
            for (size_t i = 0; i < size; ++i)
                deserialize(item[i],in);
        }
        catch (serialization_error& e)
        { 
            item.clear();
            throw serialization_error(e.info + "\n while deserializing object of type array"); 
        }
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Releases the element buffer, if any, back to pool.
    template <
        typename T,
        typename mem_manager
        >
    array::
    ~array (
    )
    {
        if (array_elements)
        {
            pool.deallocate_array(array_elements);
        }
    }

// ----------------------------------------------------------------------------------------

    // Returns the array to its initial, empty state: resets the enumerator and
    // releases the buffer so both size() and max_size() become 0.
    template <
        typename T,
        typename mem_manager
        >
    void array::
    clear (
    )
    {
        reset();
        last_pos = 0;
        array_size = 0;
        if (array_elements)
        {
            pool.deallocate_array(array_elements);
        }
        array_elements = 0;
        max_array_size = 0;

    }

// 
---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > const T& array:: operator[] ( size_t pos ) const { // make sure requires clause is not broken DLIB_ASSERT( pos < this->size() , "\tconst T& array::operator[]" << "\n\tpos must < size()" << "\n\tpos: " << pos << "\n\tsize(): " << this->size() << "\n\tthis: " << this ); return array_elements[pos]; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > T& array:: operator[] ( size_t pos ) { // make sure requires clause is not broken DLIB_ASSERT( pos < this->size() , "\tT& array::operator[]" << "\n\tpos must be < size()" << "\n\tpos: " << pos << "\n\tsize(): " << this->size() << "\n\tthis: " << this ); return array_elements[pos]; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > void array:: set_size ( size_t size ) { // make sure requires clause is not broken DLIB_CASSERT(( size <= this->max_size() ), "\tvoid array::set_size" << "\n\tsize must be <= max_size()" << "\n\tsize: " << size << "\n\tmax size: " << this->max_size() << "\n\tthis: " << this ); reset(); array_size = size; if (size > 0) last_pos = array_elements + size - 1; else last_pos = 0; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > size_t array:: size ( ) const { return array_size; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > void array:: set_max_size( size_t max ) { reset(); array_size = 0; last_pos = 0; if (max != 0) { // if new max size is different if (max != max_array_size) { if (array_elements) { pool.deallocate_array(array_elements); } // try to get more memroy try { array_elements = pool.allocate_array(max); } 
catch (...) { array_elements = 0; max_array_size = 0; throw; } max_array_size = max; } } // if the array is being made to be zero else { if (array_elements) pool.deallocate_array(array_elements); max_array_size = 0; array_elements = 0; } } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > size_t array:: max_size ( ) const { return max_array_size; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > void array:: swap ( array& item ) { auto array_size_temp = item.array_size; auto max_array_size_temp = item.max_array_size; T* array_elements_temp = item.array_elements; item.array_size = array_size; item.max_array_size = max_array_size; item.array_elements = array_elements; array_size = array_size_temp; max_array_size = max_array_size_temp; array_elements = array_elements_temp; exchange(_at_start,item._at_start); exchange(pos,item.pos); exchange(last_pos,item.last_pos); pool.swap(item.pool); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // enumerable function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > bool array:: at_start ( ) const { return _at_start; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > void array:: reset ( ) const { _at_start = true; pos = 0; } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > bool array:: current_element_valid ( ) const { return pos != 0; } // 
----------------------------------------------------------------------------------------

    // Returns the element the enumerator currently points at (read-only).
    // requires: current_element_valid() == true  (checked with DLIB_ASSERT)
    template <
        typename T,
        typename mem_manager
        >
    const T& array::
    element (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(this->current_element_valid(),
            "\tconst T& array::element()"
            << "\n\tThe current element must be valid if you are to access it."
            << "\n\tthis: " << this
            );

        return *pos;
    }

// ----------------------------------------------------------------------------------------

    // Returns the element the enumerator currently points at (mutable).
    // requires: current_element_valid() == true  (checked with DLIB_ASSERT)
    template <
        typename T,
        typename mem_manager
        >
    T& array::
    element (
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(this->current_element_valid(),
            "\tT& array::element()"
            << "\n\tThe current element must be valid if you are to access it."
            << "\n\tthis: " << this
            );

        return *pos;
    }

// ----------------------------------------------------------------------------------------

    // Advances the enumerator and returns true if it now points at an element.
    // The first call after construction/reset() moves to element 0 (if any).
    // Later calls step forward until last_pos; stepping past it sets pos to 0
    // and returns false.
    template <
        typename T,
        typename mem_manager
        >
    bool array::
    move_next (
    ) const
    {
        if (!_at_start)
        {
            if (pos < last_pos)
            {
                ++pos;
                return true;
            }
            else
            {
                // Walked off the end: no element is current any more.
                pos = 0;
                return false;
            }
        }
        else
        {
            _at_start = false;
            if (array_size > 0)
            {
                pos = array_elements;
                return true;
            }
            else
            {
                return false;
            }
        }
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // Yet more functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Sorts the elements in place with dlib::qsort_array, then resets the enumerator.
    template <
        typename T,
        typename mem_manager
        >
    void array::
    sort (
    )
    {
        if (this->size() > 1)
        {
            // call the quick sort function for arrays that is in algs.h
            // NOTE(review): qsort_array is presumably provided by ../sort.h (included by
            // this file) rather than algs.h; verify and update the comment above.
            dlib::qsort_array(*this,0,this->size()-1);
        }
        this->reset();
    }

// ----------------------------------------------------------------------------------------

    // Changes size() to new_size.  If new_size exceeds max_size(), a new buffer of
    // exactly new_size elements is allocated.  The existing elements are exchanged
    // into it, and the new buffer is then swapped into *this.  Otherwise this is
    // just set_size(new_size).
    template <
        typename T,
        typename mem_manager
        >
    void array::
    resize (
        size_t new_size
    )
    {
        if (this->max_size() < new_size)
        {
            array temp;
            temp.set_max_size(new_size);
            temp.set_size(new_size);
            for (size_t i = 0; i < this->size(); ++i)
            {
                exchange((*this)[i],temp[i]);
            }
            temp.swap(*this);
        }
        else
        {
            this->set_size(new_size);
        }
    }

// ----------------------------------------------------------------------------------------

    // Returns the last element (mutable).
    // requires: size() > 0  (checked with DLIB_ASSERT)
    template <
        typename T,
        typename mem_manager
        >
    T& array::
    back (
    ) 
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( this->size() > 0 , 
            "\tT& array::back()"
            << "\n\tsize() must be bigger than 0" 
            << "\n\tsize(): " << this->size() 
            << "\n\tthis: " << this
            );

        return (*this)[this->size()-1];
    }

// ----------------------------------------------------------------------------------------

    // Returns the last element (read-only).
    // requires: size() > 0  (checked with DLIB_ASSERT)
    template <
        typename T,
        typename mem_manager
        >
    const T& array::
    back (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( this->size() > 0 , 
            "\tconst T& array::back()"
            << "\n\tsize() must be bigger than 0" 
            << "\n\tsize(): " << this->size() 
            << "\n\tthis: " << this
            );

        return (*this)[this->size()-1];
    }

// ----------------------------------------------------------------------------------------

    // Exchanges item with the last element, then shrinks size() by one.  The buffer
    // slot that held the last element is not destroyed; it just falls outside size().
    // requires: size() > 0  (checked with DLIB_ASSERT)
    template <
        typename T,
        typename mem_manager
        >
    void array::
    pop_back (
        T& item
    ) 
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( this->size() > 0 , 
            "\tvoid array::pop_back()"
            << "\n\tsize() must be bigger than 0" 
            << "\n\tsize(): " << this->size() 
            << "\n\tthis: " << this
            );

        exchange(item,(*this)[this->size()-1]);
        this->set_size(this->size()-1);
    }

// ----------------------------------------------------------------------------------------

    // Drops the last element by shrinking size() by one (no destructor is run).
    // requires: size() > 0  (checked with DLIB_ASSERT)
    template <
        typename T,
        typename mem_manager
        >
    void array::
    pop_back (
    ) 
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( this->size() > 0 , 
            "\tvoid array::pop_back()"
            << "\n\tsize() must be bigger than 0" 
            << "\n\tsize(): " << this->size() 
            << "\n\tthis: " << this
            );

        this->set_size(this->size()-1);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename T,
        typename mem_manager
        >
    void array::
    push_back (
        T& 
item ) { if (this->max_size() == this->size()) { // double the size of the array array temp; temp.set_max_size(this->size()*2 + 1); temp.set_size(this->size()+1); for (size_t i = 0; i < this->size(); ++i) { exchange((*this)[i],temp[i]); } exchange(item,temp[temp.size()-1]); temp.swap(*this); } else { this->set_size(this->size()+1); exchange(item,(*this)[this->size()-1]); } } // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > void array:: push_back ( T&& item ) { push_back(item); } // ---------------------------------------------------------------------------------------- template struct is_array > { const static bool value = true; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_ARRAY_KERNEl_2_ ================================================ FILE: dlib/array/array_kernel_abstract.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_ARRAY_KERNEl_ABSTRACT_ #ifdef DLIB_ARRAY_KERNEl_ABSTRACT_ #include "../interfaces/enumerable.h" #include "../serialize.h" #include "../algs.h" namespace dlib { template < typename T, typename mem_manager = default_memory_manager > class array : public enumerable { /*! REQUIREMENTS ON T T must have a default constructor. REQUIREMENTS ON mem_manager must be an implementation of memory_manager/memory_manager_kernel_abstract.h or must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h mem_manager::type can be set to anything. POINTERS AND REFERENCES TO INTERNAL DATA front(), back(), swap(), max_size(), set_size(), and operator[] functions do not invalidate pointers or references to internal data. 
All other functions have no such guarantee. INITIAL VALUE size() == 0 max_size() == 0 ENUMERATION ORDER The enumerator will iterate over the elements of the array in the order (*this)[0], (*this)[1], (*this)[2], ... WHAT THIS OBJECT REPRESENTS This object represents an ordered 1-dimensional array of items, each item is associated with an integer value. The items are numbered from 0 though size() - 1 and the operator[] functions run in constant time. Also note that unless specified otherwise, no member functions of this object throw exceptions. !*/ public: typedef T type; typedef T value_type; typedef mem_manager mem_manager_type; array ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc or any exception thrown by T's constructor !*/ explicit array ( size_t new_size ); /*! ensures - #*this is properly initialized - #size() == new_size - #max_size() == new_size - All elements of the array will have initial values for their type. throws - std::bad_alloc or any exception thrown by T's constructor !*/ ~array ( ); /*! ensures - all memory associated with *this has been released !*/ array( array&& item ); /*! ensures - move constructs *this from item. Therefore, the state of item is moved into *this and #item has a valid but unspecified state. !*/ array& operator=( array&& item ); /*! ensures - move assigns *this from item. Therefore, the state of item is moved into *this and #item has a valid but unspecified state. - returns a reference to #*this !*/ void clear ( ); /*! ensures - #*this has its initial value throws - std::bad_alloc or any exception thrown by T's constructor if this exception is thrown then the array object is unusable until clear() is called and succeeds !*/ const T& operator[] ( size_t pos ) const; /*! requires - pos < size() ensures - returns a const reference to the element at position pos !*/ T& operator[] ( size_t pos ); /*! 
requires - pos < size() ensures - returns a non-const reference to the element at position pos !*/ void set_size ( size_t size ); /*! requires - size <= max_size() ensures - #size() == size - any element with index between 0 and size - 1 which was in the array before the call to set_size() retains its value and index. All other elements have undetermined (but valid for their type) values. (e.g. this object might buffer old T objects and reuse them without reinitializing them between calls to set_size()) - #at_start() == true throws - std::bad_alloc or any exception thrown by T's constructor may throw this exception if there is not enough memory and if it does throw then the call to set_size() has no effect !*/ size_t max_size( ) const; /*! ensures - returns the maximum size of *this !*/ void set_max_size( size_t max ); /*! ensures - #max_size() == max - #size() == 0 - #at_start() == true throws - std::bad_alloc or any exception thrown by T's constructor may throw this exception if there is not enough memory and if it does throw then max_size() == 0 !*/ void swap ( array& item ); /*! ensures - swaps *this and item !*/ void sort ( ); /*! requires - T must be a type with that is comparable via operator< ensures - for all elements in #*this the ith element is <= the i+1 element - #at_start() == true throws - std::bad_alloc or any exception thrown by T's constructor data may be lost if sort() throws !*/ void resize ( size_t new_size ); /*! ensures - #size() == new_size - #max_size() == max(new_size,max_size()) - for all i < size() && i < new_size: - #(*this)[i] == (*this)[i] (i.e. All the original elements of *this which were at index values less than new_size are unmodified.) - for all valid i >= size(): - #(*this)[i] has an undefined value (i.e. any new elements of the array have an undefined value) throws - std::bad_alloc or any exception thrown by T's constructor. If an exception is thrown then it has no effect on *this. !*/ const T& back ( ) const; /*! 
requires - size() != 0 ensures - returns a const reference to (*this)[size()-1] !*/ T& back ( ); /*! requires - size() != 0 ensures - returns a non-const reference to (*this)[size()-1] !*/ void pop_back ( T& item ); /*! requires - size() != 0 ensures - #size() == size() - 1 - swaps (*this)[size()-1] into item - All elements with an index less than size()-1 are unmodified by this operation. !*/ void pop_back ( ); /*! requires - size() != 0 ensures - #size() == size() - 1 - All elements with an index less than size()-1 are unmodified by this operation. !*/ void push_back ( T& item ); /*! ensures - #size() == size()+1 - swaps item into (*this)[#size()-1] - #back() == item - #item has some undefined value (whatever happens to get swapped out of the array) throws - std::bad_alloc or any exception thrown by T's constructor. If an exception is thrown then it has no effect on *this. !*/ void push_back (T&& item) { push_back(item); } /*! enable push_back from rvalues !*/ typedef T* iterator; typedef const T* const_iterator; iterator begin( ); /*! ensures - returns an iterator that points to the first element in this array or end() if the array is empty. !*/ const_iterator begin( ) const; /*! ensures - returns a const iterator that points to the first element in this array or end() if the array is empty. !*/ iterator end( ); /*! ensures - returns an iterator that points to one past the end of the array. !*/ const_iterator end( ) const; /*! ensures - returns a const iterator that points to one past the end of the array. !*/ private: // restricted functions array(array&); // copy constructor array& operator=(array&); // assignment operator }; template < typename T > inline void swap ( array& a, array& b ) { a.swap(b); } /*! provides a global swap function !*/ template < typename T > void serialize ( const array& item, std::ostream& out ); /*! provides serialization support !*/ template < typename T > void deserialize ( array& item, std::istream& in ); /*! 
provides deserialization support !*/ } #endif // DLIB_ARRAY_KERNEl_ABSTRACT_ ================================================ FILE: dlib/array/array_tools.h ================================================ // Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ARRAY_tOOLS_H_ #define DLIB_ARRAY_tOOLS_H_ #include "../assert.h" #include "array_tools_abstract.h" namespace dlib { template void split_array ( T& a, T& b, double frac ) { // make sure requires clause is not broken DLIB_ASSERT(0 <= frac && frac <= 1, "\t void split_array()" << "\n\t frac must be between 0 and 1." << "\n\t frac: " << frac ); const unsigned long asize = static_cast(a.size()*frac); const unsigned long bsize = a.size()-asize; b.resize(bsize); for (unsigned long i = 0; i < b.size(); ++i) { swap(b[i], a[i+asize]); } a.resize(asize); } } #endif // DLIB_ARRAY_tOOLS_H_ ================================================ FILE: dlib/array/array_tools_abstract.h ================================================ // Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_ARRAY_tOOLS_ABSTRACT_H_ #ifdef DLIB_ARRAY_tOOLS_ABSTRACT_H_ #include "array_kernel_abstract.h" namespace dlib { template void split_array ( T& a, T& b, double frac ); /*! requires - 0 <= frac <= 1 - T must be an array type such as dlib::array or std::vector ensures - This function takes the elements of a and splits them into two groups. The first group remains in a and the second group is put into b. The ordering of elements in a is preserved. In particular, concatenating #a with #b will reproduce the original contents of a. - The elements in a are moved around using global swap(). So they must be swappable, but do not need to be copyable. 
- #a.size() == floor(a.size()*frac) - #b.size() == a.size()-#a.size() !*/ } #endif // DLIB_ARRAY_tOOLS_ABSTRACT_H_ ================================================ FILE: dlib/array.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ARRAy_ #define DLIB_ARRAy_ #include "array/array_kernel.h" #include "array/array_tools.h" #endif // DLIB_ARRAy_ ================================================ FILE: dlib/array2d/array2d_generic_image.h ================================================ // Copyright (C) 2014 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ #define DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ #include "array2d_kernel.h" #include "../image_processing/generic_image.h" namespace dlib { template struct image_traits > { typedef T pixel_type; }; template struct image_traits > { typedef T pixel_type; }; template inline long num_rows( const array2d& img) { return img.nr(); } template inline long num_columns( const array2d& img) { return img.nc(); } template inline void set_image_size( array2d& img, long rows, long cols ) { img.set_size(rows,cols); } template inline void* image_data( array2d& img ) { if (img.size() != 0) return &img[0][0]; else return 0; } template inline const void* image_data( const array2d& img ) { if (img.size() != 0) return &img[0][0]; else return 0; } template inline size_t width_step( const array2d& img ) { return img.width_step(); } } #endif // DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ ================================================ FILE: dlib/array2d/array2d_kernel.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_ARRAY2D_KERNEl_1_ #define DLIB_ARRAY2D_KERNEl_1_ #include "array2d_kernel_abstract.h" #include "../algs.h" #include "../interfaces/enumerable.h" #include "../serialize.h" #include "../geometry/rectangle.h" namespace dlib { template < typename T, typename mem_manager = default_memory_manager > class array2d : public enumerable { /*! INITIAL VALUE - nc_ == 0 - nr_ == 0 - data == 0 - at_start_ == true - cur == 0 - last == 0 CONVENTION - nc_ == nc() - nr_ == nc() - if (data != 0) then - last == a pointer to the last element in the data array - data == pointer to an array of nc_*nr_ T objects - else - nc_ == 0 - nr_ == 0 - data == 0 - last == 0 - nr_ * nc_ == size() - if (cur == 0) then - current_element_valid() == false - else - current_element_valid() == true - *cur == element() - at_start_ == at_start() !*/ class row_helper; public: // These typedefs are here for backwards compatibility with older versions of dlib. typedef array2d kernel_1a; typedef array2d kernel_1a_c; typedef T type; typedef mem_manager mem_manager_type; typedef T* iterator; typedef const T* const_iterator; // ----------------------------------- class row { /*! CONVENTION - nc_ == nc() - for all x < nc_: - (*this)[x] == data[x] !*/ friend class array2d; friend class row_helper; public: long nc ( ) const { return nc_; } const T& operator[] ( long column ) const { // make sure requires clause is not broken DLIB_ASSERT(column < nc() && column >= 0, "\tconst T& array2d::operator[](long column) const" << "\n\tThe column index given must be less than the number of columns." << "\n\tthis: " << this << "\n\tcolumn: " << column << "\n\tnc(): " << nc() ); return data[column]; } T& operator[] ( long column ) { // make sure requires clause is not broken DLIB_ASSERT(column < nc() && column >= 0, "\tT& array2d::operator[](long column)" << "\n\tThe column index given must be less than the number of columns." 
<< "\n\tthis: " << this << "\n\tcolumn: " << column << "\n\tnc(): " << nc() ); return data[column]; } private: row(T* data_, long cols) : data(data_), nc_(cols) {} row(row&& r) = default; row& operator=(row&& r) = default; T* data = nullptr; long nc_ = 0; // restricted functions row(const row&) = delete; row& operator=(const row&) = delete; }; // ----------------------------------- array2d ( ) : data(0), nc_(0), nr_(0), cur(0), last(0), at_start_(true) { } array2d( long rows, long cols ) : data(0), nc_(0), nr_(0), cur(0), last(0), at_start_(true) { // make sure requires clause is not broken DLIB_ASSERT((cols >= 0 && rows >= 0), "\t array2d::array2d(long rows, long cols)" << "\n\t The array2d can't have negative rows or columns." << "\n\t this: " << this << "\n\t cols: " << cols << "\n\t rows: " << rows ); set_size(rows,cols); } array2d(const array2d&) = delete; // copy constructor array2d& operator=(const array2d&) = delete; // assignment operator #ifdef DLIB_HAS_RVALUE_REFERENCES array2d(array2d&& item) : array2d() { swap(item); } array2d& operator= ( array2d&& rhs ) { swap(rhs); return *this; } #endif virtual ~array2d ( ) { clear(); } long nc ( ) const { return nc_; } long nr ( ) const { return nr_; } row operator[] ( long row_ ) { // make sure requires clause is not broken DLIB_ASSERT(row_ < nr() && row_ >= 0, "\trow array2d::operator[](long row_)" << "\n\tThe row index given must be less than the number of rows." << "\n\tthis: " << this << "\n\trow_: " << row_ << "\n\tnr(): " << nr() ); return row(data+row_*nc_, nc_); } const row operator[] ( long row_ ) const { // make sure requires clause is not broken DLIB_ASSERT(row_ < nr() && row_ >= 0, "\tconst row array2d::operator[](long row_) const" << "\n\tThe row index given must be less than the number of rows." 
<< "\n\tthis: " << this << "\n\trow_: " << row_ << "\n\tnr(): " << nr() ); return row(data+row_*nc_, nc_); } void swap ( array2d& item ) { exchange(data,item.data); exchange(nr_,item.nr_); exchange(nc_,item.nc_); exchange(at_start_,item.at_start_); exchange(cur,item.cur); exchange(last,item.last); pool.swap(item.pool); } void clear ( ) { if (data != 0) { pool.deallocate_array(data); nc_ = 0; nr_ = 0; data = 0; at_start_ = true; cur = 0; last = 0; } } void set_size ( long rows, long cols ); bool at_start ( ) const { return at_start_; } void reset ( ) const { at_start_ = true; cur = 0; } bool current_element_valid ( ) const { return (cur != 0); } const T& element ( ) const { // make sure requires clause is not broken DLIB_ASSERT(current_element_valid() == true, "\tconst T& array2d::element()()" << "\n\tYou can only call element() when you are at a valid one." << "\n\tthis: " << this ); return *cur; } T& element ( ) { // make sure requires clause is not broken DLIB_ASSERT(current_element_valid() == true, "\tT& array2d::element()()" << "\n\tYou can only call element() when you are at a valid one." 
<< "\n\tthis: " << this ); return *cur; } bool move_next ( ) const { if (cur != 0) { if (cur != last) { ++cur; return true; } cur = 0; return false; } else if (at_start_) { cur = data; at_start_ = false; return (data != 0); } else { return false; } } size_t size ( ) const { return static_cast(nc_) * static_cast(nr_); } size_t width_step ( ) const { return nc_*sizeof(T); } iterator begin() { return data; } iterator end() { return data+size(); } const_iterator begin() const { return data; } const_iterator end() const { return data+size(); } private: T* data; long nc_; long nr_; typename mem_manager::template rebind::other pool; mutable T* cur; T* last; mutable bool at_start_; }; // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > inline void swap ( array2d& a, array2d& b ) { a.swap(b); } template < typename T, typename mem_manager > void serialize ( const array2d& item, std::ostream& out ) { try { // The reason the serialization is a little funny is because we are trying to // maintain backwards compatibility with an older serialization format used by // dlib while also encoding things in a way that lets the array2d and matrix // objects have compatible serialization formats. 
serialize(-item.nr(),out); serialize(-item.nc(),out); item.reset(); while (item.move_next()) serialize(item.element(),out); item.reset(); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type array2d"); } } template < typename T, typename mem_manager > void deserialize ( array2d& item, std::istream& in ) { try { long nr, nc; deserialize(nr,in); deserialize(nc,in); // this is the newer serialization format if (nr < 0 || nc < 0) { nr *= -1; nc *= -1; } else { std::swap(nr,nc); } item.set_size(nr,nc); while (item.move_next()) deserialize(item.element(),in); item.reset(); } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type array2d"); } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename T, typename mem_manager > void array2d:: set_size ( long rows, long cols ) { // make sure requires clause is not broken DLIB_ASSERT((cols >= 0 && rows >= 0) , "\tvoid array2d::set_size(long rows, long cols)" << "\n\tThe array2d can't have negative rows or columns." << "\n\tthis: " << this << "\n\tcols: " << cols << "\n\trows: " << rows ); // set the enumerator back at the start at_start_ = true; cur = 0; // don't do anything if we are already the right size. if (nc_ == cols && nr_ == rows) { return; } nc_ = cols; nr_ = rows; // free any existing memory if (data != 0) { pool.deallocate_array(data); data = 0; } // now setup this object to have the new size try { if (nr_ > 0) { data = pool.allocate_array(nr_*nc_); last = data + nr_*nc_ - 1; } } catch (...) 
{ if (data) pool.deallocate_array(data); data = 0; nc_ = 0; nr_ = 0; last = 0; throw; } } // ---------------------------------------------------------------------------------------- template struct is_array2d > { const static bool value = true; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_ARRAY2D_KERNEl_1_ ================================================ FILE: dlib/array2d/array2d_kernel_abstract.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_ARRAY2D_KERNEl_ABSTRACT_ #ifdef DLIB_ARRAY2D_KERNEl_ABSTRACT_ #include "../interfaces/enumerable.h" #include "../serialize.h" #include "../algs.h" #include "../geometry/rectangle_abstract.h" namespace dlib { template < typename T, typename mem_manager = default_memory_manager > class array2d : public enumerable { /*! REQUIREMENTS ON T T must have a default constructor. REQUIREMENTS ON mem_manager must be an implementation of memory_manager/memory_manager_kernel_abstract.h or must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h mem_manager::type can be set to anything. POINTERS AND REFERENCES TO INTERNAL DATA No member functions in this object will invalidate pointers or references to internal data except for the set_size() and clear() member functions. INITIAL VALUE nr() == 0 nc() == 0 ENUMERATION ORDER The enumerator will iterate over the elements of the array starting with row 0 and then proceeding to row 1 and so on. Each row will be fully enumerated before proceeding on to the next row and the elements in a row will be enumerated beginning with the 0th column, then the 1st column and so on. 
WHAT THIS OBJECT REPRESENTS This object represents a 2-Dimensional array of objects of type T. Also note that unless specified otherwise, no member functions of this object throw exceptions. Finally, note that this object stores its data contiguously and in row major order. Moreover, there is no padding at the end of each row. This means that its width_step() value is always equal to sizeof(type)*nc(). !*/ public: // ---------------------------------------- typedef T type; typedef mem_manager mem_manager_type; typedef T* iterator; typedef const T* const_iterator; // ---------------------------------------- class row { /*! POINTERS AND REFERENCES TO INTERNAL DATA No member functions in this object will invalidate pointers or references to internal data. WHAT THIS OBJECT REPRESENTS This object represents a row of Ts in an array2d object. !*/ public: long nc ( ) const; /*! ensures - returns the number of columns in this row !*/ const T& operator[] ( long column ) const; /*! requires - 0 <= column < nc() ensures - returns a const reference to the T in the given column !*/ T& operator[] ( long column ); /*! requires - 0 <= column < nc() ensures - returns a non-const reference to the T in the given column !*/ private: // restricted functions row(); row& operator=(row&); }; // ---------------------------------------- array2d ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc !*/ array2d(const array2d&) = delete; // copy constructor array2d& operator=(const array2d&) = delete; // assignment operator array2d( array2d&& item ); /*! ensures - Moves the state of item into *this. - #item is in a valid but unspecified state. !*/ array2d ( long rows, long cols ); /*! requires - rows >= 0 && cols >= 0 ensures - #nc() == cols - #nr() == rows - #at_start() == true - all elements in this array have initial values for their type throws - std::bad_alloc !*/ virtual ~array2d ( ); /*! 
ensures - all resources associated with *this has been released !*/ void clear ( ); /*! ensures - #*this has an initial value for its type !*/ long nc ( ) const; /*! ensures - returns the number of elements there are in a row. i.e. returns the number of columns in *this !*/ long nr ( ) const; /*! ensures - returns the number of rows in *this !*/ void set_size ( long rows, long cols ); /*! requires - rows >= 0 && cols >= 0 ensures - #nc() == cols - #nr() == rows - #at_start() == true - if (the call to set_size() doesn't change the dimensions of this array) then - all elements in this array retain their values from before this function was called - else - all elements in this array have initial values for their type throws - std::bad_alloc If this exception is thrown then #*this will have an initial value for its type. !*/ row operator[] ( long row_index ); /*! requires - 0 <= row_index < nr() ensures - returns a non-const row of nc() elements that represents the given row_index'th row in *this. !*/ const row operator[] ( long row_index ) const; /*! requires - 0 <= row_index < nr() ensures - returns a const row of nc() elements that represents the given row_index'th row in *this. !*/ void swap ( array2d& item ); /*! ensures - swaps *this and item !*/ array2d& operator= ( array2d&& rhs ); /*! ensures - Moves the state of item into *this. - #item is in a valid but unspecified state. - returns #*this !*/ size_t width_step ( ) const; /*! ensures - returns the size of one row of the image, in bytes. More precisely, return a number N such that: (char*)&item[0][0] + N == (char*)&item[1][0]. - for dlib::array2d objects, the returned value is always equal to sizeof(type)*nc(). However, other objects which implement dlib::array2d style interfaces might have padding at the ends of their rows and therefore might return larger numbers. An example of such an object is the dlib::cv_image. !*/ iterator begin( ); /*! 
ensures - returns a random access iterator pointing to the first element in this object. - The iterator will iterate over the elements of the object in row major order. !*/ iterator end( ); /*! ensures - returns a random access iterator pointing to one past the end of the last element in this object. !*/ const_iterator begin( ) const; /*! ensures - returns a random access iterator pointing to the first element in this object. - The iterator will iterate over the elements of the object in row major order. !*/ const_iterator end( ) const; /*! ensures - returns a random access iterator pointing to one past the end of the last element in this object. !*/ }; template < typename T, typename mem_manager > inline void swap ( array2d& a, array2d& b ) { a.swap(b); } /*! provides a global swap function !*/ template < typename T, typename mem_manager > void serialize ( const array2d& item, std::ostream& out ); /*! Provides serialization support. Note that the serialization formats used by the dlib::matrix and dlib::array2d objects are compatible. That means you can load the serialized data from one into another and it will work properly. !*/ template < typename T, typename mem_manager > void deserialize ( array2d& item, std::istream& in ); /*! provides deserialization support !*/ } #endif // DLIB_ARRAY2D_KERNEl_ABSTRACT_ ================================================ FILE: dlib/array2d/serialize_pixel_overloads.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ #define DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ #include "array2d_kernel.h" #include "../pixel.h" namespace dlib { // ---------------------------------------------------------------------------------------- /* This file contains overloads of the serialize functions for array2d object for the case where they contain simple 8bit POD pixel types. In these cases we can perform a much faster serialization by writing data in chunks instead of one pixel at a time (this avoids a lot of function call overhead inside the iostreams). */ // ---------------------------------------------------------------------------------------- template < typename mem_manager > void serialize ( const array2d& item, std::ostream& out ) { try { // The reason the serialization is a little funny is because we are trying to // maintain backwards compatibility with an older serialization format used by // dlib while also encoding things in a way that lets the array2d and matrix // objects have compatible serialization formats. 
serialize(-item.nr(),out); serialize(-item.nc(),out); COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); if (item.size() != 0) out.write((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type array2d"); } } template < typename mem_manager > void deserialize ( array2d& item, std::istream& in ) { try { COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); long nr, nc; deserialize(nr,in); deserialize(nc,in); // this is the newer serialization format if (nr < 0 || nc < 0) { nr *= -1; nc *= -1; } else { std::swap(nr,nc); } item.set_size(nr,nc); if (item.size() != 0) in.read((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type array2d"); } } // ---------------------------------------------------------------------------------------- template < typename mem_manager > void serialize ( const array2d& item, std::ostream& out ) { try { // The reason the serialization is a little funny is because we are trying to // maintain backwards compatibility with an older serialization format used by // dlib while also encoding things in a way that lets the array2d and matrix // objects have compatible serialization formats. 
serialize(-item.nr(),out); serialize(-item.nc(),out); COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); if (item.size() != 0) out.write((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type array2d"); } } template < typename mem_manager > void deserialize ( array2d& item, std::istream& in ) { try { COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); long nr, nc; deserialize(nr,in); deserialize(nc,in); // this is the newer serialization format if (nr < 0 || nc < 0) { nr *= -1; nc *= -1; } else { std::swap(nr,nc); } item.set_size(nr,nc); if (item.size() != 0) in.read((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type array2d"); } } // ---------------------------------------------------------------------------------------- template < typename mem_manager > void serialize ( const array2d& item, std::ostream& out ) { try { // The reason the serialization is a little funny is because we are trying to // maintain backwards compatibility with an older serialization format used by // dlib while also encoding things in a way that lets the array2d and matrix // objects have compatible serialization formats. 
serialize(-item.nr(),out); serialize(-item.nc(),out); COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); if (item.size() != 0) out.write((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type array2d"); } } template < typename mem_manager > void deserialize ( array2d& item, std::istream& in ) { try { COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); long nr, nc; deserialize(nr,in); deserialize(nc,in); // this is the newer serialization format if (nr < 0 || nc < 0) { nr *= -1; nc *= -1; } else { std::swap(nr,nc); } item.set_size(nr,nc); if (item.size() != 0) in.read((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type array2d"); } } // ---------------------------------------------------------------------------------------- template < typename mem_manager > void serialize ( const array2d& item, std::ostream& out ) { try { // The reason the serialization is a little funny is because we are trying to // maintain backwards compatibility with an older serialization format used by // dlib while also encoding things in a way that lets the array2d and matrix // objects have compatible serialization formats. 
serialize(-item.nr(),out); serialize(-item.nc(),out); COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); if (item.size() != 0) out.write((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type array2d"); } } template < typename mem_manager > void deserialize ( array2d& item, std::istream& in ) { try { COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); long nr, nc; deserialize(nr,in); deserialize(nc,in); // this is the newer serialization format if (nr < 0 || nc < 0) { nr *= -1; nc *= -1; } else { std::swap(nr,nc); } item.set_size(nr,nc); if (item.size() != 0) in.read((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type array2d"); } } // ---------------------------------------------------------------------------------------- template < typename mem_manager > void serialize ( const array2d& item, std::ostream& out ) { try { // The reason the serialization is a little funny is because we are trying to // maintain backwards compatibility with an older serialization format used by // dlib while also encoding things in a way that lets the array2d and matrix // objects have compatible serialization formats. 
serialize(-item.nr(),out); serialize(-item.nc(),out); if (item.size() != 0) out.write((char*)&item[0][0], sizeof(unsigned char)*item.size()); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type array2d"); } } template < typename mem_manager > void deserialize ( array2d& item, std::istream& in ) { try { long nr, nc; deserialize(nr,in); deserialize(nc,in); // this is the newer serialization format if (nr < 0 || nc < 0) { nr *= -1; nc *= -1; } else { std::swap(nr,nc); } item.set_size(nr,nc); if (item.size() != 0) in.read((char*)&item[0][0], sizeof(unsigned char)*item.size()); } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type array2d"); } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ ================================================ FILE: dlib/array2d.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ARRAY2d_ #define DLIB_ARRAY2d_ #include "array2d/array2d_kernel.h" #include "array2d/serialize_pixel_overloads.h" #include "array2d/array2d_generic_image.h" #endif // DLIB_ARRAY2d_ ================================================ FILE: dlib/assert.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ASSERt_ #define DLIB_ASSERt_ #include "config.h" #include #include #include "error.h" // ----------------------------- // Use some stuff from boost here // (C) Copyright John Maddock 2001 - 2003. // (C) Copyright Darin Adler 2001. // (C) Copyright Peter Dimov 2001. // (C) Copyright Bill Kempf 2002. // (C) Copyright Jens Maurer 2002. // (C) Copyright David Abrahams 2002 - 2003. 
// (C) Copyright Gennaro Prota 2003. // (C) Copyright Eric Friedman 2003. // License: Boost Software License See LICENSE.txt for the full license. // #ifndef DLIB_BOOST_JOIN #define DLIB_BOOST_JOIN( X, Y ) DLIB_BOOST_DO_JOIN( X, Y ) #define DLIB_BOOST_DO_JOIN( X, Y ) DLIB_BOOST_DO_JOIN2(X,Y) #define DLIB_BOOST_DO_JOIN2( X, Y ) X##Y #endif // figure out if the compiler has rvalue references. #if defined(__clang__) # if __has_feature(cxx_rvalue_references) # define DLIB_HAS_RVALUE_REFERENCES # endif # if __has_feature(cxx_generalized_initializers) # define DLIB_HAS_INITIALIZER_LISTS # endif #elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) # define DLIB_HAS_RVALUE_REFERENCES # define DLIB_HAS_INITIALIZER_LISTS #elif defined(_MSC_VER) && _MSC_VER >= 1800 # define DLIB_HAS_INITIALIZER_LISTS # define DLIB_HAS_RVALUE_REFERENCES #elif defined(_MSC_VER) && _MSC_VER >= 1600 # define DLIB_HAS_RVALUE_REFERENCES #elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) # define DLIB_HAS_RVALUE_REFERENCES # define DLIB_HAS_INITIALIZER_LISTS #endif #if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402) // Apple has not updated libstdc++ in some time and anything under 4.02 does not have for sure. # undef DLIB_HAS_INITIALIZER_LISTS #endif // figure out if the compiler has static_assert. 
#if defined(__clang__) # if __has_feature(cxx_static_assert) # define DLIB_HAS_STATIC_ASSERT # endif #elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) # define DLIB_HAS_STATIC_ASSERT #elif defined(_MSC_VER) && _MSC_VER >= 1600 # define DLIB_HAS_STATIC_ASSERT #elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) # define DLIB_HAS_STATIC_ASSERT #endif // ----------------------------- namespace dlib { template struct compile_time_assert; template <> struct compile_time_assert { enum {value=1}; }; template struct assert_are_same_type; template struct assert_are_same_type {enum{value=1};}; template struct assert_are_not_same_type {enum{value=1}; }; template struct assert_are_not_same_type {}; template struct assert_types_match {enum{value=0};}; template struct assert_types_match {enum{value=1};}; } // gcc 4.8 will warn about unused typedefs. But we use typedefs in some of the compile // time assert macros so we need to make it not complain about them "not being used". #ifdef __GNUC__ #define DLIB_NO_WARN_UNUSED __attribute__ ((unused)) #else #define DLIB_NO_WARN_UNUSED #endif // Use the newer static_assert if it's available since it produces much more readable error // messages. 
#ifdef DLIB_HAS_STATIC_ASSERT #define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion") #define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match::value, "These types should be the same but aren't.") #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match::value, "These types should NOT be the same.") #else #define COMPILE_TIME_ASSERT(expression) \ DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value] #define ASSERT_ARE_SAME_TYPE(type1, type2) \ DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type::value] #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \ DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type::value] #endif // ----------------------------- #if defined DLIB_DISABLE_ASSERTS // if DLIB_DISABLE_ASSERTS is on then never enable DLIB_ASSERT no matter what. #undef ENABLE_ASSERTS #endif #if !defined(DLIB_DISABLE_ASSERTS) && ( defined DEBUG || defined _DEBUG) // make sure ENABLE_ASSERTS is defined if we are indeed using them. #ifndef ENABLE_ASSERTS #define ENABLE_ASSERTS #endif #endif // ----------------------------- #ifdef __GNUC__ // There is a bug in version 4.4.5 of GCC on Ubuntu which causes GCC to segfault // when __PRETTY_FUNCTION__ is used within certain templated functions. So just // don't use it with this version of GCC. 
# if !(__GNUC__ == 4 && __GNUC_MINOR__ == 4 && __GNUC_PATCHLEVEL__ == 5) # define DLIB_FUNCTION_NAME __PRETTY_FUNCTION__ # else # define DLIB_FUNCTION_NAME "unknown function" # endif #elif defined(_MSC_VER) #define DLIB_FUNCTION_NAME __FUNCSIG__ #else #define DLIB_FUNCTION_NAME "unknown function" #endif #define DLIBM_CASSERT(_exp,_message) \ {if ( !(_exp) ) \ { \ dlib_assert_breakpoint(); \ std::ostringstream dlib_o_out; \ dlib_o_out << "\n\nError detected at line " << __LINE__ << ".\n"; \ dlib_o_out << "Error detected in file " << __FILE__ << ".\n"; \ dlib_o_out << "Error detected in function " << DLIB_FUNCTION_NAME << ".\n\n"; \ dlib_o_out << "Failing expression was " << #_exp << ".\n"; \ dlib_o_out << std::boolalpha << _message << "\n"; \ throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \ }} // This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor. #define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x // Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like // DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message). #define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"") #define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message) #define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3 #define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS, DLIB_CASSERT_NEVER_USED)) #define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__)) #ifdef ENABLE_ASSERTS #define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__) #define DLIB_IF_ASSERT(exp) exp #else #define DLIB_ASSERT(...) 
{} #define DLIB_IF_ASSERT(exp) #endif // ---------------------------------------------------------------------------------------- /*!A DLIB_ASSERT_HAS_STANDARD_LAYOUT This macro is meant to cause a compiler error if a type doesn't have a simple memory layout (like a C struct). In particular, types with simple layouts are ones which can be copied via memcpy(). This was called a POD type in C++03 and in C++0x we are looking to check if it is a "standard layout type". Once we can use C++0x we can change this macro to something that uses the std::is_standard_layout type_traits class. See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs !*/ // Use the fact that in C++03 you can't put non-PODs into a union. #define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \ union DLIB_BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \ DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(DLIB_BOOST_JOIN(DAHSL_,__LINE__))]; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // breakpoints extern "C" { inline void dlib_assert_breakpoint( ) {} /*! ensures - this function does nothing It exists just so you can put breakpoints on it in a debugging tool. It is called only when an DLIB_ASSERT or DLIB_CASSERT fails and is about to throw an exception. !*/ } // ----------------------------- #include "stack_trace.h" #endif // DLIB_ASSERt_ ================================================ FILE: dlib/base64/base64_kernel_1.cpp ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BASE64_KERNEL_1_CPp_ #define DLIB_BASE64_KERNEL_1_CPp_ #include "base64_kernel_1.h" #include #include #include namespace dlib { // ---------------------------------------------------------------------------------------- base64::line_ending_type base64:: line_ending ( ) const { return eol_style; } // ---------------------------------------------------------------------------------------- void base64:: set_line_ending ( line_ending_type eol_style_ ) { eol_style = eol_style_; } // ---------------------------------------------------------------------------------------- base64:: base64 ( ) : encode_table(0), decode_table(0), bad_value(100), eol_style(LF) { try { encode_table = new char[64]; decode_table = new unsigned char[UCHAR_MAX]; } catch (...) { if (encode_table) delete [] encode_table; if (decode_table) delete [] decode_table; throw; } // now set up the tables with the right stuff encode_table[0] = 'A'; encode_table[17] = 'R'; encode_table[34] = 'i'; encode_table[51] = 'z'; encode_table[1] = 'B'; encode_table[18] = 'S'; encode_table[35] = 'j'; encode_table[52] = '0'; encode_table[2] = 'C'; encode_table[19] = 'T'; encode_table[36] = 'k'; encode_table[53] = '1'; encode_table[3] = 'D'; encode_table[20] = 'U'; encode_table[37] = 'l'; encode_table[54] = '2'; encode_table[4] = 'E'; encode_table[21] = 'V'; encode_table[38] = 'm'; encode_table[55] = '3'; encode_table[5] = 'F'; encode_table[22] = 'W'; encode_table[39] = 'n'; encode_table[56] = '4'; encode_table[6] = 'G'; encode_table[23] = 'X'; encode_table[40] = 'o'; encode_table[57] = '5'; encode_table[7] = 'H'; encode_table[24] = 'Y'; encode_table[41] = 'p'; encode_table[58] = '6'; encode_table[8] = 'I'; encode_table[25] = 'Z'; encode_table[42] = 'q'; encode_table[59] = '7'; encode_table[9] = 'J'; encode_table[26] = 'a'; encode_table[43] = 'r'; encode_table[60] = '8'; encode_table[10] = 'K'; encode_table[27] = 'b'; encode_table[44] = 's'; encode_table[61] = '9'; encode_table[11] = 'L'; encode_table[28] = 
'c'; encode_table[45] = 't'; encode_table[62] = '+'; encode_table[12] = 'M'; encode_table[29] = 'd'; encode_table[46] = 'u'; encode_table[63] = '/'; encode_table[13] = 'N'; encode_table[30] = 'e'; encode_table[47] = 'v'; encode_table[14] = 'O'; encode_table[31] = 'f'; encode_table[48] = 'w'; encode_table[15] = 'P'; encode_table[32] = 'g'; encode_table[49] = 'x'; encode_table[16] = 'Q'; encode_table[33] = 'h'; encode_table[50] = 'y'; // we can now fill out the decode_table by using the encode_table for (int i = 0; i < UCHAR_MAX; ++i) { decode_table[i] = bad_value; } for (unsigned char i = 0; i < 64; ++i) { decode_table[(unsigned char)encode_table[i]] = i; } } // ---------------------------------------------------------------------------------------- base64:: ~base64 ( ) { delete [] encode_table; delete [] decode_table; } // ---------------------------------------------------------------------------------------- void base64:: encode ( std::istream& in_, std::ostream& out_ ) const { using namespace std; streambuf& in = *in_.rdbuf(); streambuf& out = *out_.rdbuf(); unsigned char inbuf[3]; unsigned char outbuf[4]; streamsize status = in.sgetn(reinterpret_cast(&inbuf),3); unsigned char c1, c2, c3, c4, c5, c6; int counter = 19; // while we haven't hit the end of the input stream while (status != 0) { if (counter == 0) { counter = 19; // write a newline char ch; switch (eol_style) { case CR: ch = '\r'; if (out.sputn(&ch,1)!=1) throw std::ios_base::failure("error occurred in the base64 object"); break; case LF: ch = '\n'; if (out.sputn(&ch,1)!=1) throw std::ios_base::failure("error occurred in the base64 object"); break; case CRLF: ch = '\r'; if (out.sputn(&ch,1)!=1) throw std::ios_base::failure("error occurred in the base64 object"); ch = '\n'; if (out.sputn(&ch,1)!=1) throw std::ios_base::failure("error occurred in the base64 object"); break; default: DLIB_CASSERT(false,"this should never happen"); } } --counter; if (status == 3) { // encode the bytes in inbuf to base64 
and write them to the output stream c1 = inbuf[0]&0xfc; c2 = inbuf[0]&0x03; c3 = inbuf[1]&0xf0; c4 = inbuf[1]&0x0f; c5 = inbuf[2]&0xc0; c6 = inbuf[2]&0x3f; outbuf[0] = c1>>2; outbuf[1] = (c2<<4)|(c3>>4); outbuf[2] = (c4<<2)|(c5>>6); outbuf[3] = c6; outbuf[0] = encode_table[outbuf[0]]; outbuf[1] = encode_table[outbuf[1]]; outbuf[2] = encode_table[outbuf[2]]; outbuf[3] = encode_table[outbuf[3]]; // write the encoded bytes to the output stream if (out.sputn(reinterpret_cast(&outbuf),4)!=4) { throw std::ios_base::failure("error occurred in the base64 object"); } // get 3 more input bytes status = in.sgetn(reinterpret_cast(&inbuf),3); continue; } else if (status == 2) { // we are at the end of the input stream and need to add some padding // encode the bytes in inbuf to base64 and write them to the output stream c1 = inbuf[0]&0xfc; c2 = inbuf[0]&0x03; c3 = inbuf[1]&0xf0; c4 = inbuf[1]&0x0f; c5 = 0; outbuf[0] = c1>>2; outbuf[1] = (c2<<4)|(c3>>4); outbuf[2] = (c4<<2)|(c5>>6); outbuf[3] = '='; outbuf[0] = encode_table[outbuf[0]]; outbuf[1] = encode_table[outbuf[1]]; outbuf[2] = encode_table[outbuf[2]]; // write the encoded bytes to the output stream if (out.sputn(reinterpret_cast(&outbuf),4)!=4) { throw std::ios_base::failure("error occurred in the base64 object"); } break; } else // in this case status must be 1 { // we are at the end of the input stream and need to add some padding // encode the bytes in inbuf to base64 and write them to the output stream c1 = inbuf[0]&0xfc; c2 = inbuf[0]&0x03; c3 = 0; outbuf[0] = c1>>2; outbuf[1] = (c2<<4)|(c3>>4); outbuf[2] = '='; outbuf[3] = '='; outbuf[0] = encode_table[outbuf[0]]; outbuf[1] = encode_table[outbuf[1]]; // write the encoded bytes to the output stream if (out.sputn(reinterpret_cast(&outbuf),4)!=4) { throw std::ios_base::failure("error occurred in the base64 object"); } break; } } // while (status != 0) // make sure the stream buffer flushes to its I/O channel out.pubsync(); } // 
---------------------------------------------------------------------------------------- void base64:: decode ( std::istream& in_, std::ostream& out_ ) const { using namespace std; streambuf& in = *in_.rdbuf(); streambuf& out = *out_.rdbuf(); unsigned char inbuf[4]; unsigned char outbuf[3]; int inbuf_pos = 0; streamsize status = in.sgetn(reinterpret_cast(inbuf),1); // only count this character if it isn't some kind of filler if (status == 1 && decode_table[inbuf[0]] != bad_value ) ++inbuf_pos; unsigned char c1, c2, c3, c4, c5, c6; streamsize outsize; // while we haven't hit the end of the input stream while (status != 0) { // if we have 4 valid characters if (inbuf_pos == 4) { inbuf_pos = 0; // this might be the end of the encoded data so we need to figure out if // there was any padding applied. outsize = 3; if (inbuf[3] == '=') { if (inbuf[2] == '=') outsize = 1; else outsize = 2; } // decode the incoming characters inbuf[0] = decode_table[inbuf[0]]; inbuf[1] = decode_table[inbuf[1]]; inbuf[2] = decode_table[inbuf[2]]; inbuf[3] = decode_table[inbuf[3]]; // now pack these guys into bytes rather than 6 bit chunks c1 = inbuf[0]<<2; c2 = inbuf[1]>>4; c3 = inbuf[1]<<4; c4 = inbuf[2]>>2; c5 = inbuf[2]<<6; c6 = inbuf[3]; outbuf[0] = c1|c2; outbuf[1] = c3|c4; outbuf[2] = c5|c6; // write the encoded bytes to the output stream if (out.sputn(reinterpret_cast(&outbuf),outsize)!=outsize) { throw std::ios_base::failure("error occurred in the base64 object"); } } // get more input characters status = in.sgetn(reinterpret_cast(inbuf + inbuf_pos),1); // only count this character if it isn't some kind of filler if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') && status != 0) ++inbuf_pos; } // while (status != 0) if (inbuf_pos != 0) { ostringstream sout; sout << inbuf_pos << " extra characters were found at the end of the encoded data." 
<< " This may indicate that the data stream has been truncated."; // this happens if we hit EOF in the middle of decoding a 24bit block. throw decode_error(sout.str()); } // make sure the stream buffer flushes to its I/O channel out.pubsync(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BASE64_KERNEL_1_CPp_ ================================================ FILE: dlib/base64/base64_kernel_1.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BASE64_KERNEl_1_ #define DLIB_BASE64_KERNEl_1_ #include "../algs.h" #include "base64_kernel_abstract.h" #include namespace dlib { class base64 { /*! INITIAL VALUE - bad_value == 100 - encode_table == a pointer to an array of 64 chars - where x is a 6 bit value the following is true: - encode_table[x] == the base64 encoding of x - decode_table == a pointer to an array of UCHAR_MAX chars - where x is any char value: - if (x is a valid character in the base64 coding scheme) then - decode_table[x] == the 6 bit value that x encodes - else - decode_table[x] == bad_value CONVENTION - The state of this object never changes so just refer to its initial value. !*/ public: // this is here for backwards compatibility with older versions of dlib. typedef base64 kernel_1a; class decode_error : public dlib::error { public: decode_error( const std::string& e) : error(e) {}}; base64 ( ); virtual ~base64 ( ); enum line_ending_type { CR, // i.e. "\r" LF, // i.e. "\n" CRLF // i.e. 
"\r\n" }; line_ending_type line_ending ( ) const; void set_line_ending ( line_ending_type eol_style_ ); void encode ( std::istream& in, std::ostream& out ) const; void decode ( std::istream& in, std::ostream& out ) const; private: char* encode_table; unsigned char* decode_table; const unsigned char bad_value; line_ending_type eol_style; // restricted functions base64(base64&); // copy constructor base64& operator=(base64&); // assignment operator }; } #ifdef NO_MAKEFILE #include "base64_kernel_1.cpp" #endif #endif // DLIB_BASE64_KERNEl_1_ ================================================ FILE: dlib/base64/base64_kernel_abstract.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BASE64_KERNEl_ABSTRACT_ #ifdef DLIB_BASE64_KERNEl_ABSTRACT_ #include "../algs.h" #include namespace dlib { class base64 { /*! INITIAL VALUE - line_ending() == LF WHAT THIS OBJECT REPRESENTS This object consists of the two functions encode and decode. These functions allow you to encode and decode data to and from the Base64 Content-Transfer-Encoding defined in section 6.8 of rfc2045. !*/ public: class decode_error : public dlib::error {}; base64 ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc !*/ virtual ~base64 ( ); /*! ensures - all memory associated with *this has been released !*/ enum line_ending_type { CR, // i.e. "\r" LF, // i.e. "\n" CRLF // i.e. "\r\n" }; line_ending_type line_ending ( ) const; /*! ensures - returns the type of end of line bytes the encoder will use when encoding data to base64 blocks. Note that the ostream object you use might apply some sort of transform to line endings as well. For example, C++ ofstream objects usually convert '\n' into whatever a normal newline is for your platform unless you open a file in binary mode. 
But aside from file streams the ostream objects usually don't modify the data you pass to them. !*/ void set_line_ending ( line_ending_type eol_style ); /*! ensures - #line_ending() == eol_style !*/ void encode ( std::istream& in, std::ostream& out ) const; /*! ensures - reads all data from in (until EOF is reached) and encodes it and writes it to out throws - std::ios_base::failure if there was a problem writing to out then this exception will be thrown. - any other exception this exception may be thrown if there is any other problem !*/ void decode ( std::istream& in, std::ostream& out ) const; /*! ensures - reads data from in (until EOF is reached), decodes it, and writes it to out. throws - std::ios_base::failure if there was a problem writing to out then this exception will be thrown. - decode_error if an error was detected in the encoded data that prevented it from being correctly decoded then this exception is thrown. - any other exception this exception may be thrown if there is any other problem !*/ private: // restricted functions base64(base64&); // copy constructor base64& operator=(base64&); // assignment operator }; } #endif // DLIB_BASE64_KERNEl_ABSTRACT_ ================================================ FILE: dlib/base64.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BASe64_ #define DLIB_BASe64_ #include "base64/base64_kernel_1.h" #endif // DLIB_BASe64_ ================================================ FILE: dlib/bayes_utils/bayes_utils.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BAYES_UTILs_ #define DLIB_BAYES_UTILs_ #include "bayes_utils_abstract.h" #include #include #include #include #include "../string.h" #include "../map.h" #include "../matrix.h" #include "../rand.h" #include "../array.h" #include "../set.h" #include "../algs.h" #include "../noncopyable.h" #include "../graph.h" namespace dlib { // ---------------------------------------------------------------------------------------- class assignment { public: assignment() { } assignment( const assignment& a ) { a.reset(); while (a.move_next()) { unsigned long idx = a.element().key(); unsigned long value = a.element().value(); vals.add(idx,value); } } assignment& operator = ( const assignment& rhs ) { if (this == &rhs) return *this; assignment(rhs).swap(*this); return *this; } void clear() { vals.clear(); } bool operator < ( const assignment& item ) const { if (size() < item.size()) return true; else if (size() > item.size()) return false; reset(); item.reset(); while (move_next()) { item.move_next(); if (element().key() < item.element().key()) return true; else if (element().key() > item.element().key()) return false; else if (element().value() < item.element().value()) return true; else if (element().value() > item.element().value()) return false; } return false; } bool has_index ( unsigned long idx ) const { return vals.is_in_domain(idx); } void add ( unsigned long idx, unsigned long value = 0 ) { // make sure requires clause is not broken DLIB_ASSERT( has_index(idx) == false , "\tvoid assignment::add(idx)" << "\n\tYou can't add the same index to an assignment object more than once" << "\n\tidx: " << idx << "\n\tthis: " << this ); vals.add(idx, value); } unsigned long& operator[] ( const long idx ) { // make sure requires clause is not broken DLIB_ASSERT( has_index(idx) == true , "\tunsigned long assignment::operator[](idx)" << "\n\tYou can't access an index value if it isn't already in the object" << "\n\tidx: " << idx << "\n\tthis: " << this ); return vals[idx]; } 
const unsigned long& operator[] ( const long idx ) const { // make sure requires clause is not broken DLIB_ASSERT( has_index(idx) == true , "\tunsigned long assignment::operator[](idx)" << "\n\tYou can't access an index value if it isn't already in the object" << "\n\tidx: " << idx << "\n\tthis: " << this ); return vals[idx]; } void swap ( assignment& item ) { vals.swap(item.vals); } void remove ( unsigned long idx ) { // make sure requires clause is not broken DLIB_ASSERT( has_index(idx) == true , "\tunsigned long assignment::remove(idx)" << "\n\tYou can't remove an index value if it isn't already in the object" << "\n\tidx: " << idx << "\n\tthis: " << this ); vals.destroy(idx); } unsigned long size() const { return vals.size(); } void reset() const { vals.reset(); } bool move_next() const { return vals.move_next(); } map_pair& element() { // make sure requires clause is not broken DLIB_ASSERT(current_element_valid() == true, "\tmap_pair& assignment::element()" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); return vals.element(); } const map_pair& element() const { // make sure requires clause is not broken DLIB_ASSERT(current_element_valid() == true, "\tconst map_pair& assignment::element() const" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); return vals.element(); } bool at_start() const { return vals.at_start(); } bool current_element_valid() const { return vals.current_element_valid(); } friend inline void serialize ( const assignment& item, std::ostream& out ) { serialize(item.vals, out); } friend inline void deserialize ( assignment& item, std::istream& in ) { deserialize(item.vals, in); } private: mutable dlib::map::kernel_1b_c vals; }; inline std::ostream& operator << ( std::ostream& out, const assignment& a ) { a.reset(); out << "("; if (a.move_next()) out << a.element().key() << ":" << a.element().value(); while (a.move_next()) { out << ", " << a.element().key() 
<< ":" << a.element().value(); } out << ")"; return out; } inline void swap ( assignment& a, assignment& b ) { a.swap(b); } // ------------------------------------------------------------------------ class joint_probability_table { /*! INITIAL VALUE - table.size() == 0 CONVENTION - size() == table.size() - probability(a) == table[a] !*/ public: joint_probability_table ( const joint_probability_table& t ) { t.reset(); while (t.move_next()) { assignment a = t.element().key(); double p = t.element().value(); set_probability(a,p); } } joint_probability_table() {} joint_probability_table& operator= ( const joint_probability_table& rhs ) { if (this == &rhs) return *this; joint_probability_table(rhs).swap(*this); return *this; } void set_probability ( const assignment& a, double p ) { // make sure requires clause is not broken DLIB_ASSERT(0.0 <= p && p <= 1.0, "\tvoid& joint_probability_table::set_probability(a,p)" << "\n\tyou have given an invalid probability value" << "\n\tp: " << p << "\n\ta: " << a << "\n\tthis: " << this ); if (table.is_in_domain(a)) { table[a] = p; } else { assignment temp(a); table.add(temp,p); } } bool has_entry_for ( const assignment& a ) const { return table.is_in_domain(a); } void add_probability ( const assignment& a, double p ) { // make sure requires clause is not broken DLIB_ASSERT(0.0 <= p && p <= 1.0, "\tvoid& joint_probability_table::add_probability(a,p)" << "\n\tyou have given an invalid probability value" << "\n\tp: " << p << "\n\ta: " << a << "\n\tthis: " << this ); if (table.is_in_domain(a)) { table[a] += p; if (table[a] > 1.0) table[a] = 1.0; } else { assignment temp(a); table.add(temp,p); } } double probability ( const assignment& a ) const { return table[a]; } void clear() { table.clear(); } size_t size () const { return table.size(); } bool move_next() const { return table.move_next(); } void reset() const { table.reset(); } map_pair& element() { // make sure requires clause is not broken DLIB_ASSERT(current_element_valid() == 
true, "\tmap_pair& joint_probability_table::element()" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); return table.element(); } const map_pair& element() const { // make sure requires clause is not broken DLIB_ASSERT(current_element_valid() == true, "\tconst map_pair& joint_probability_table::element() const" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); return table.element(); } bool at_start() const { return table.at_start(); } bool current_element_valid() const { return table.current_element_valid(); } template void marginalize ( const T& vars, joint_probability_table& out ) const { out.clear(); double p; reset(); while (move_next()) { assignment a; const assignment& asrc = element().key(); p = element().value(); asrc.reset(); while (asrc.move_next()) { if (vars.is_member(asrc.element().key())) a.add(asrc.element().key(), asrc.element().value()); } out.add_probability(a,p); } } void marginalize ( const unsigned long var, joint_probability_table& out ) const { out.clear(); double p; reset(); while (move_next()) { assignment a; const assignment& asrc = element().key(); p = element().value(); asrc.reset(); while (asrc.move_next()) { if (var == asrc.element().key()) a.add(asrc.element().key(), asrc.element().value()); } out.add_probability(a,p); } } void normalize ( ) { double sum = 0; reset(); while (move_next()) sum += element().value(); reset(); while (move_next()) element().value() /= sum; } void swap ( joint_probability_table& item ) { table.swap(item.table); } friend inline void serialize ( const joint_probability_table& item, std::ostream& out ) { serialize(item.table, out); } friend inline void deserialize ( joint_probability_table& item, std::istream& in ) { deserialize(item.table, in); } private: dlib::map::kernel_1b_c table; }; inline void swap ( joint_probability_table& a, joint_probability_table& b ) { a.swap(b); } // 
---------------------------------------------------------------------------------------- class conditional_probability_table : noncopyable { /*! INITIAL VALUE - table.size() == 0 CONVENTION - if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then - has_entry_for(value,ps) == true - probability(value,ps) == table[ps](value) - else - has_entry_for(value,ps) == false - num_values() == num_vals !*/ public: conditional_probability_table() { clear(); } void set_num_values ( unsigned long num ) { num_vals = num; table.clear(); } bool has_entry_for ( unsigned long value, const assignment& ps ) const { if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) return true; else return false; } unsigned long num_values ( ) const { return num_vals; } void set_probability ( unsigned long value, const assignment& ps, double p ) { // make sure requires clause is not broken DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 , "\tvoid conditional_probability_table::set_probability()" << "\n\tinvalid arguments to set_probability" << "\n\tvalue: " << value << "\n\tnum_values(): " << num_values() << "\n\tp: " << p << "\n\tps: " << ps << "\n\tthis: " << this ); if (table.is_in_domain(ps)) { table[ps](value) = p; } else { matrix dist(num_vals); set_all_elements(dist,-1); dist(value) = p; assignment temp(ps); table.add(temp,dist); } } double probability( unsigned long value, const assignment& ps ) const { // make sure requires clause is not broken DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) , "\tvoid conditional_probability_table::probability()" << "\n\tinvalid arguments to probability" << "\n\tvalue: " << value << "\n\tnum_values(): " << num_values() << "\n\tps: " << ps << "\n\tthis: " << this ); return table[ps](value); } void clear() { table.clear(); num_vals = 0; } void empty_table () { table.clear(); } void swap ( conditional_probability_table& item ) { exchange(num_vals, item.num_vals); table.swap(item.table); } friend 
inline void serialize ( const conditional_probability_table& item, std::ostream& out ) { serialize(item.table, out); serialize(item.num_vals, out); } friend inline void deserialize ( conditional_probability_table& item, std::istream& in ) { deserialize(item.table, in); deserialize(item.num_vals, in); } private: dlib::map >::kernel_1b_c table; unsigned long num_vals; }; inline void swap ( conditional_probability_table& a, conditional_probability_table& b ) { a.swap(b); } // ------------------------------------------------------------------------ class bayes_node : noncopyable { public: bayes_node () { is_instantiated = false; value_ = 0; } unsigned long value ( ) const { return value_;} void set_value ( unsigned long new_value ) { // make sure requires clause is not broken DLIB_ASSERT( new_value < table().num_values(), "\tvoid bayes_node::set_value(new_value)" << "\n\tnew_value must be less than the number of possible values for this node" << "\n\tnew_value: " << new_value << "\n\ttable().num_values(): " << table().num_values() << "\n\tthis: " << this ); value_ = new_value; } conditional_probability_table& table ( ) { return table_; } const conditional_probability_table& table ( ) const { return table_; } bool is_evidence ( ) const { return is_instantiated; } void set_as_nonevidence ( ) { is_instantiated = false; } void set_as_evidence ( ) { is_instantiated = true; } void swap ( bayes_node& item ) { exchange(value_, item.value_); exchange(is_instantiated, item.is_instantiated); table_.swap(item.table_); } friend inline void serialize ( const bayes_node& item, std::ostream& out ) { serialize(item.value_, out); serialize(item.is_instantiated, out); serialize(item.table_, out); } friend inline void deserialize ( bayes_node& item, std::istream& in ) { deserialize(item.value_, in); deserialize(item.is_instantiated, in); deserialize(item.table_, in); } private: unsigned long value_; bool is_instantiated; conditional_probability_table table_; }; inline void swap ( 
bayes_node& a, bayes_node& b ) { a.swap(b); } // ------------------------------------------------------------------------ namespace bayes_node_utils { template unsigned long node_num_values ( const T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tvoid bayes_node_utils::node_num_values(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); return bn.node(n).data.table().num_values(); } // ---------------------------------------------------------------------------------------- template void set_node_value ( T& bn, unsigned long n, unsigned long val ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n), "\tvoid bayes_node_utils::set_node_value(bn, n, val)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tval: " << val << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) ); bn.node(n).data.set_value(val); } // ---------------------------------------------------------------------------------------- template unsigned long node_value ( const T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tunsigned long bayes_node_utils::node_value(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); return bn.node(n).data.value(); } // ---------------------------------------------------------------------------------------- template bool node_is_evidence ( const T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tbool bayes_node_utils::node_is_evidence(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); return bn.node(n).data.is_evidence(); } 
// ---------------------------------------------------------------------------------------- template void set_node_as_evidence ( T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tvoid bayes_node_utils::set_node_as_evidence(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); bn.node(n).data.set_as_evidence(); } // ---------------------------------------------------------------------------------------- template void set_node_as_nonevidence ( T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); bn.node(n).data.set_as_nonevidence(); } // ---------------------------------------------------------------------------------------- template void set_node_num_values ( T& bn, unsigned long n, unsigned long num ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tvoid bayes_node_utils::set_node_num_values(bn, n, num)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); bn.node(n).data.table().set_num_values(num); } // ---------------------------------------------------------------------------------------- template double node_probability ( const T& bn, unsigned long n, unsigned long value, const assignment& parents ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tvalue: " << value << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) ); 
DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tparents.size(): " << parents.size() << "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents() ); #ifdef ENABLE_ASSERTS parents.reset(); while (parents.move_next()) { const unsigned long x = parents.element().key(); DLIB_ASSERT( bn.has_edge(x, n), "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tx: " << x ); DLIB_ASSERT( parents[x] < node_num_values(bn,x), "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tx: " << x << "\n\tparents[x]: " << parents[x] << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) ); } #endif return bn.node(n).data.table().probability(value, parents); } // ---------------------------------------------------------------------------------------- template void set_node_probability ( T& bn, unsigned long n, unsigned long value, const assignment& parents, double p ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tp: " << p << "\n\tvalue: " << value << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) ); DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tp: " << p << "\n\tparents.size(): " << parents.size() << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() ); DLIB_ASSERT( 0.0 <= p && p <= 1.0, "\tvoid 
bayes_node_utils::set_node_probability(bn, n, value, parents, p)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tp: " << p ); #ifdef ENABLE_ASSERTS parents.reset(); while (parents.move_next()) { const unsigned long x = parents.element().key(); DLIB_ASSERT( bn.has_edge(x, n), "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tx: " << x ); DLIB_ASSERT( parents[x] < node_num_values(bn,x), "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tx: " << x << "\n\tparents[x]: " << parents[x] << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) ); } #endif bn.node(n).data.table().set_probability(value,parents,p); } // ---------------------------------------------------------------------------------------- template const assignment node_first_parent_assignment ( const T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n ); assignment a; const unsigned long num_parents = bn.node(n).number_of_parents(); for (unsigned long i = 0; i < num_parents; ++i) { a.add(bn.node(n).parent(i).index(), 0); } return a; } // ---------------------------------------------------------------------------------------- template bool node_next_parent_assignment ( const T& bn, unsigned long n, assignment& a ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n ); DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(), "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << 
"\n\ta.size(): " << a.size() << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() ); #ifdef ENABLE_ASSERTS a.reset(); while (a.move_next()) { const unsigned long x = a.element().key(); DLIB_ASSERT( bn.has_edge(x, n), "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tx: " << x ); DLIB_ASSERT( a[x] < node_num_values(bn,x), "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tx: " << x << "\n\ta[x]: " << a[x] << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) ); } #endif // basically this loop just adds 1 to the assignment but performs // carries if necessary for (unsigned long p = 0; p < a.size(); ++p) { const unsigned long pindex = bn.node(n).parent(p).index(); a[pindex] += 1; // if we need to perform a carry if (a[pindex] >= node_num_values(bn,pindex)) { a[pindex] = 0; } else { // no carry necessary so we are done return true; } } // we got through the entire loop which means a carry propagated all the way out // so there must not be any more valid assignments left return false; } // ---------------------------------------------------------------------------------------- template bool node_cpt_filled_out ( const T& bn, unsigned long n ) { // make sure requires clause is not broken DLIB_ASSERT( n < bn.number_of_nodes(), "\tbool bayes_node_utils::node_cpt_filled_out(bn, n)" << "\n\tInvalid arguments to this function" << "\n\tn: " << n << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() ); const unsigned long num_values = node_num_values(bn,n); const conditional_probability_table& table = bn.node(n).data.table(); // now loop over all the possible parent assignments for this node assignment a(node_first_parent_assignment(bn,n)); do { double sum = 0; // make sure that this assignment has an entry for all the values this node can take one for (unsigned long value = 0; value 
< num_values; ++value) { if (table.has_entry_for(value,a) == false) return false; else sum += table.probability(value,a); } // check if the sum of probabilities equals 1 as it should if (std::abs(sum-1.0) > 1e-5) return false; } while (node_next_parent_assignment(bn,n,a)); return true; } } // ---------------------------------------------------------------------------------------- class bayesian_network_gibbs_sampler : noncopyable { public: bayesian_network_gibbs_sampler () { rnd.set_seed(cast_to_string(std::time(0))); } template < typename T > void sample_graph ( T& bn ) { using namespace bayes_node_utils; for (unsigned long n = 0; n < bn.number_of_nodes(); ++n) { if (node_is_evidence(bn, n)) continue; samples.set_size(node_num_values(bn,n)); // obtain the probability distribution for this node for (long i = 0; i < samples.nc(); ++i) { set_node_value(bn, n, i); samples(i) = node_probability(bn, n); for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j) samples(i) *= node_probability(bn, bn.node(n).child(j).index()); } //normalize samples samples /= sum(samples); // select a random point in the probability distribution double prob = rnd.get_random_double(); // now find the point in the distribution this probability corresponds to long j; for (j = 0; j < samples.nc()-1; ++j) { if (prob <= samples(j)) break; else prob -= samples(j); } set_node_value(bn, n, j); } } private: template < typename T > double node_probability ( const T& bn, unsigned long n ) /*! 
requires - n < bn.number_of_nodes() ensures - computes the probability of node n having its current value given the current values of its parents in the network bn !*/ { v.clear(); for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i) { v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value()); } return bn.node(n).data.table().probability(bn.node(n).data.value(), v); } assignment v; dlib::rand rnd; matrix samples; }; // ---------------------------------------------------------------------------------------- namespace bayesian_network_join_tree_helpers { class bnjt { /*! this object is the base class used in this pimpl idiom !*/ public: virtual ~bnjt() {} virtual const matrix probability( unsigned long idx ) const = 0; }; template class bnjt_impl : public bnjt { /*! This object is the implementation in the pimpl idiom !*/ public: bnjt_impl ( const T& bn, const U& join_tree ) { create_bayesian_network_join_tree(bn, join_tree, join_tree_values); cliques.resize(bn.number_of_nodes()); // figure out which cliques contain each node for (unsigned long i = 0; i < cliques.size(); ++i) { // find the smallest clique that contains node with index i unsigned long smallest_clique = 0; unsigned long size = std::numeric_limits::max(); for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n) { if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size) { size = join_tree.node(n).data.size(); smallest_clique = n; } } cliques[i] = smallest_clique; } } virtual const matrix probability( unsigned long idx ) const { join_tree_values.node(cliques[idx]).data.marginalize(idx, table); table.normalize(); var.clear(); var.add(idx); dist.set_size(table.size()); // read the probabilities out of the table and into the row matrix for (unsigned long i = 0; i < table.size(); ++i) { var[idx] = i; dist(i) = table.probability(var); } return dist; } private: graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values; array 
cliques; mutable joint_probability_table table; mutable assignment var; mutable matrix dist; // ---------------------------------------------------------------------------------------- template bool set_contains_all_parents_of_node ( const set_type& set, const node_type& node ) { for (unsigned long i = 0; i < node.number_of_parents(); ++i) { if (set.is_member(node.parent(i).index()) == false) return false; } return true; } // ---------------------------------------------------------------------------------------- template < typename V > void pass_join_tree_message ( const U& join_tree, V& bn_join_tree , unsigned long from, unsigned long to ) { using namespace bayes_node_utils; const typename U::edge_type& e = edge(join_tree, from, to); typename V::edge_type& old_s = edge(bn_join_tree, from, to); typedef typename V::edge_type joint_prob_table; joint_prob_table new_s; bn_join_tree.node(from).data.marginalize(e, new_s); joint_probability_table temp(new_s); // divide new_s by old_s and store the result in temp. // if old_s is empty then that is the same as if it was all 1s // so we don't have to do this if that is the case. if (old_s.size() > 0) { temp.reset(); old_s.reset(); while (temp.move_next()) { old_s.move_next(); if (old_s.element().value() != 0) temp.element().value() /= old_s.element().value(); } } // now multiply temp by d and store the results in d joint_probability_table& d = bn_join_tree.node(to).data; d.reset(); while (d.move_next()) { assignment a; const assignment& asrc = d.element().key(); asrc.reset(); while (asrc.move_next()) { if (e.is_member(asrc.element().key())) a.add(asrc.element().key(), asrc.element().value()); } d.element().value() *= temp.probability(a); } // store new_s in old_s new_s.swap(old_s); } // ---------------------------------------------------------------------------------------- template < typename V > void create_bayesian_network_join_tree ( const T& bn, const U& join_tree, V& bn_join_tree ) /*! 
requires - bn is a proper bayesian network - join_tree is the join tree for that bayesian network ensures - bn_join_tree == the output of the join tree algorithm for bayesian network inference. So each node in this graph contains a joint_probability_table for the clique in the corresponding node in the join_tree graph. !*/ { using namespace bayes_node_utils; bn_join_tree.clear(); copy_graph_structure(join_tree, bn_join_tree); // we need to keep track of which node is "in" each clique for the purposes of // initializing the tables in each clique. So this vector will be used to do that // and a value of join_tree.number_of_nodes() means that the node with // that index is unassigned. std::vector node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes()); // populate evidence with all the evidence node indices and their values dlib::map::kernel_1b_c evidence; for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) { if (node_is_evidence(bn, i)) { unsigned long idx = i; unsigned long value = node_value(bn, i); evidence.add(idx,value); } } // initialize the bn join tree for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) { bool contains_evidence = false; std::vector indices; assignment value; // loop over all the nodes in this clique in the join tree. In this loop // we are making an assignment with all the values of the nodes it represents set to 0 join_tree.node(i).data.reset(); while (join_tree.node(i).data.move_next()) { const unsigned long idx = join_tree.node(i).data.element(); indices.push_back(idx); value.add(idx); if (evidence.is_in_domain(join_tree.node(i).data.element())) contains_evidence = true; } // now loop over all possible combinations of values that the nodes this // clique in the join tree can take on. 
We do this by counting by one through all // legal values bool more_assignments = true; while (more_assignments) { bn_join_tree.node(i).data.set_probability(value,1); // account for any evidence if (contains_evidence) { // loop over all the nodes in this cluster for (unsigned long j = 0; j < indices.size(); ++j) { // if the current node is an evidence node if (evidence.is_in_domain(indices[j])) { const unsigned long idx = indices[j]; const unsigned long evidence_value = evidence[idx]; if (value[idx] != evidence_value) bn_join_tree.node(i).data.set_probability(value , 0); } } } // now check if any of the nodes in this cluster also have their parents in this cluster join_tree.node(i).data.reset(); while (join_tree.node(i).data.move_next()) { const unsigned long idx = join_tree.node(i).data.element(); // if this clique contains all the parents of this node and also hasn't // been assigned to another clique if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) && (i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) ) { // note that this node is now assigned to this clique node_assigned_to[idx] = i; // node idx has all its parents in the cluster assignment parent_values; for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j) { const unsigned long pidx = bn.node(idx).parent(j).index(); parent_values.add(pidx, value[pidx]); } double temp = bn_join_tree.node(i).data.probability(value); bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values)); } } // now advance the value variable to its next possible state if there is one more_assignments = false; value.reset(); while (value.move_next()) { value.element().value() += 1; // if overflow if (value.element().value() == node_num_values(bn, value.element().key())) { value.element().value() = 0; } else { more_assignments = true; break; } } } // end while (more_assignments) } // the tree is now initialized. 
Now all we need to do is perform the propagation and // we are done dlib::array::compare_1b_c> remaining_msg_to_send; dlib::array::compare_1b_c> remaining_msg_to_receive; remaining_msg_to_receive.resize(join_tree.number_of_nodes()); remaining_msg_to_send.resize(join_tree.number_of_nodes()); for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i) { for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j) { const unsigned long idx = join_tree.node(i).neighbor(j).index(); unsigned long temp; temp = idx; remaining_msg_to_receive[i].add(temp); temp = idx; remaining_msg_to_send[i].add(temp); } } // now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received // a message from. // we will consider node 0 to be the root node. bool message_sent = true; while (message_sent) { message_sent = false; for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i) { // if node i hasn't sent any messages but has received all but one then send a message to the one // node who hasn't sent i a message if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1) { unsigned long to; // get the last remaining thing from this set remaining_msg_to_receive[i].remove_any(to); // send the message pass_join_tree_message(join_tree, bn_join_tree, i, to); // record that we sent this message remaining_msg_to_send[i].destroy(to); remaining_msg_to_receive[to].destroy(i); // put to back in since we still need to receive it remaining_msg_to_receive[i].add(to); message_sent = true; } else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0) { unsigned long to; remaining_msg_to_send[i].remove_any(to); remaining_msg_to_receive[to].destroy(i); pass_join_tree_message(join_tree, bn_join_tree, i, to); message_sent = true; } } if (remaining_msg_to_receive[0].size() == 0) { // send a message to all of the root nodes neighbors unless we have already sent out he 
messages while (remaining_msg_to_send[0].size() > 0) { unsigned long to; remaining_msg_to_send[0].remove_any(to); remaining_msg_to_receive[to].destroy(0); pass_join_tree_message(join_tree, bn_join_tree, 0, to); message_sent = true; } } } } }; } class bayesian_network_join_tree : noncopyable { /*! use the pimpl idiom to push the template arguments from the class level to the constructor level !*/ public: template < typename T, typename U > bayesian_network_join_tree ( const T& bn, const U& join_tree ) { // make sure requires clause is not broken DLIB_ASSERT( bn.number_of_nodes() > 0 , "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" << "\n\tYou have given an invalid bayesian network" << "\n\tthis: " << this ); DLIB_ASSERT( is_join_tree(bn, join_tree) == true , "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" << "\n\tYou have given an invalid join tree for the supplied bayesian network" << "\n\tthis: " << this ); DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false, "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" << "\n\tYou have given an invalid bayesian network" << "\n\tthis: " << this ); DLIB_ASSERT( graph_is_connected(bn) == true, "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" << "\n\tYou have given an invalid bayesian network" << "\n\tthis: " << this ); #ifdef ENABLE_ASSERTS for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) { DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true, "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" << "\n\tYou have given an invalid bayesian network. 
" << "\n\tYou must finish filling out the conditional_probability_table of node " << i << "\n\tthis: " << this ); } #endif impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl(bn, join_tree)); num_nodes = bn.number_of_nodes(); } const matrix probability( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT( idx < number_of_nodes() , "\tconst matrix bayesian_network_join_tree::probability(idx)" << "\n\tYou have specified an invalid node index" << "\n\tidx: " << idx << "\n\tnumber_of_nodes(): " << number_of_nodes() << "\n\tthis: " << this ); return impl->probability(idx); } unsigned long number_of_nodes ( ) const { return num_nodes; } void swap ( bayesian_network_join_tree& item ) { exchange(num_nodes, item.num_nodes); impl.swap(item.impl); } private: std::unique_ptr impl; unsigned long num_nodes; }; inline void swap ( bayesian_network_join_tree& a, bayesian_network_join_tree& b ) { a.swap(b); } } // ---------------------------------------------------------------------------------------- #endif // DLIB_BAYES_UTILs_ ================================================ FILE: dlib/bayes_utils/bayes_utils_abstract.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BAYES_UTILs_ABSTRACT_ #ifdef DLIB_BAYES_UTILs_ABSTRACT_ #include "../algs.h" #include "../noncopyable.h" #include "../interfaces/enumerable.h" #include "../interfaces/map_pair.h" #include "../serialize.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- class assignment : public enumerable > { /*! INITIAL VALUE - size() == 0 ENUMERATION ORDER The enumerator will iterate over the entries in the assignment in ascending order according to index values. (i.e.
the elements are enumerated in sorted order according to the value of their keys) WHAT THIS OBJECT REPRESENTS This object models an assignment of random variables to particular values. It is used with the joint_probability_table and conditional_probability_table objects to represent assignments of various random variables to actual values. So for example, if you had a joint_probability_table that represented the following table: P(A = 0, B = 0) = 0.2 P(A = 0, B = 1) = 0.3 P(A = 1, B = 0) = 0.1 P(A = 1, B = 1) = 0.4 Also let's define an enum so we have concrete index numbers for A and B enum { A = 0, B = 1}; Then you could query the value of P(A=1, B=0) as follows: assignment a; a.add(A, 1); a.add(B, 0); // and now it is the case that: table.probability(a) == 0.1 a[A] == 1 a[B] == 0 Also note that when enumerating the elements of an assignment object the key() refers to the index and the value() refers to the value at that index. For example: // assume a is an assignment object a.reset(); while (a.move_next()) { // in this loop it is always the case that: // a[a.element().key()] == a.element().value() } !*/ public: assignment( ); /*! ensures - this object is properly initialized !*/ assignment( const assignment& a ); /*! ensures - #*this is a copy of a !*/ assignment& operator = ( const assignment& rhs ); /*! ensures - #*this is a copy of rhs - returns *this !*/ void clear( ); /*! ensures - this object has been returned to its initial value !*/ bool operator < ( const assignment& item ) const; /*! ensures - The exact functioning of this operator is undefined. The only guarantee is that it establishes a total ordering on all possible assignment objects. In other words, this operator makes it so that you can use assignment objects in the associative containers but otherwise isn't of any particular use. !*/ bool has_index ( unsigned long idx ) const; /*!
ensures - if (this assignment object has an entry for index idx) then - returns true - else - returns false !*/ void add ( unsigned long idx, unsigned long value = 0 ); /*! requires - has_index(idx) == false ensures - #has_index(idx) == true - #(*this)[idx] == value !*/ void remove ( unsigned long idx ); /*! requires - has_index(idx) == true ensures - #has_index(idx) == false !*/ unsigned long& operator[] ( const long idx ); /*! requires - has_index(idx) == true ensures - returns a reference to the value associated with index idx !*/ const unsigned long& operator[] ( const long idx ) const; /*! requires - has_index(idx) == true ensures - returns a const reference to the value associated with index idx !*/ void swap ( assignment& item ); /*! ensures - swaps *this and item !*/ }; inline void swap ( assignment& a, assignment& b ) { a.swap(b); } /*! provides a global swap !*/ std::ostream& operator << ( std::ostream& out, const assignment& a ); /*! ensures - writes a to the given output stream in the following format: (index1:value1, index2:value2, ..., indexN:valueN) !*/ void serialize ( const assignment& item, std::ostream& out ); /*! provides serialization support !*/ void deserialize ( assignment& item, std::istream& in ); /*! provides deserialization support !*/ // ------------------------------------------------------------------------ class joint_probability_table : public enumerable > { /*! INITIAL VALUE - size() == 0 ENUMERATION ORDER The enumerator will iterate over the entries in the probability table in no particular order but they will all be visited. WHAT THIS OBJECT REPRESENTS This object models a joint probability table. That is, it models the function p(X). So this object models the probability of a particular set of variables (referred to as X). !*/ public: joint_probability_table( ); /*! ensures - this object is properly initialized !*/ joint_probability_table ( const joint_probability_table& t ); /*!
ensures - this object is a copy of t !*/ void clear( ); /*! ensures - this object has its initial value !*/ joint_probability_table& operator= ( const joint_probability_table& rhs ); /*! ensures - this object is a copy of rhs - returns a reference to *this !*/ bool has_entry_for ( const assignment& a ) const; /*! ensures - if (this joint_probability_table has an entry for p(X = a)) then - returns true - else - returns false !*/ void set_probability ( const assignment& a, double p ); /*! requires - 0 <= p <= 1 ensures - if (has_entry_for(a) == false) then - #size() == size() + 1 - #probability(a) == p - #has_entry_for(a) == true !*/ void add_probability ( const assignment& a, double p ); /*! requires - 0 <= p <= 1 ensures - if (has_entry_for(a) == false) then - #size() == size() + 1 - #probability(a) == p - else - #probability(a) == min(probability(a) + p, 1.0) (i.e. does a saturating add) - #has_entry_for(a) == true !*/ const double probability ( const assignment& a ) const; /*! ensures - returns the probability p(X == a) !*/ template < typename T > void marginalize ( const T& vars, joint_probability_table& output_table ) const; /*! requires - T is an implementation of set/set_kernel_abstract.h ensures - marginalizes *this by summing over all variables not in vars. The result is stored in output_table. !*/ void marginalize ( const unsigned long var, joint_probability_table& output_table ) const; /*! ensures - is identical to calling the above marginalize() function with a set that contains only var. Or in other words, performs a marginalization with just one variable var. So that output_table will contain a table giving the marginal probability of var all by itself. !*/ void normalize ( ); /*! ensures - let sum == the sum of all the probabilities in this table - after normalize() has finished it will be the case that the sum of all the entries in this table is 1.0. This is accomplished by dividing all the entries by the sum described above.
!*/ void swap ( joint_probability_table& item ); /*! ensures - swaps *this and item !*/ }; inline void swap ( joint_probability_table& a, joint_probability_table& b ) { a.swap(b); } /*! provides a global swap !*/ void serialize ( const joint_probability_table& item, std::ostream& out ); /*! provides serialization support !*/ void deserialize ( joint_probability_table& item, std::istream& in ); /*! provides deserialization support !*/ // ---------------------------------------------------------------------------------------- class conditional_probability_table : noncopyable { /*! INITIAL VALUE - num_values() == 0 - has_entry_for(x, y) == false for all values of x and y WHAT THIS OBJECT REPRESENTS This object models a conditional probability table. That is, it models the function p( X | parents). So this object models the conditional probability of a particular variable (referred to as X) given another set of variables (referred to as parents). !*/ public: conditional_probability_table( ); /*! ensures - this object is properly initialized !*/ void clear( ); /*! ensures - this object has its initial value !*/ void empty_table ( ); /*! ensures - for all possible v and p: - #has_entry_for(v,p) == false (i.e. this function clears out the table when you call it but doesn't change the value of num_values()) !*/ void set_num_values ( unsigned long num ); /*! ensures - #num_values() == num - for all possible v and p: - #has_entry_for(v,p) == false (i.e. this function clears out the table when you call it) !*/ unsigned long num_values ( ) const; /*! ensures - This object models the probability table p(X | parents). This function returns the number of values X can take on. !*/ bool has_entry_for ( unsigned long value, const assignment& ps ) const; /*! ensures - if (this conditional_probability_table has an entry for p(X = value | parents = ps)) then - returns true - else - returns false !*/ void set_probability ( unsigned long value, const assignment& ps, double p ); /*!
requires - value < num_values() - 0 <= p <= 1 ensures - #probability(value, ps) == p - #has_entry_for(value, ps) == true !*/ double probability( unsigned long value, const assignment& ps ) const; /*! requires - value < num_values() - has_entry_for(value, ps) == true ensures - returns the probability p( X = value | parents = ps). !*/ void swap ( conditional_probability_table& item ); /*! ensures - swaps *this and item !*/ }; inline void swap ( conditional_probability_table& a, conditional_probability_table& b ) { a.swap(b); } /*! provides a global swap !*/ void serialize ( const conditional_probability_table& item, std::ostream& out ); /*! provides serialization support !*/ void deserialize ( conditional_probability_table& item, std::istream& in ); /*! provides deserialization support !*/ // ------------------------------------------------------------------------ // ------------------------------------------------------------------------ // ------------------------------------------------------------------------ class bayes_node : noncopyable { /*! INITIAL VALUE - is_evidence() == false - value() == 0 - table().num_values() == 0 WHAT THIS OBJECT REPRESENTS This object represents a node in a bayesian network. It is intended to be used inside the dlib::directed_graph object to represent bayesian networks. !*/ public: bayes_node ( ); /*! ensures - this object is properly initialized !*/ unsigned long value ( ) const; /*! ensures - returns the current value of this node !*/ void set_value ( unsigned long new_value ); /*! requires - new_value < table().num_values() ensures - #value() == new_value !*/ conditional_probability_table& table ( ); /*! ensures - returns a reference to the conditional_probability_table associated with this node !*/ const conditional_probability_table& table ( ) const; /*! ensures - returns a const reference to the conditional_probability_table associated with this node. !*/ bool is_evidence ( ) const; /*!
ensures - if (this is an evidence node) then - returns true - else - returns false !*/ void set_as_nonevidence ( ); /*! ensures - #is_evidence() == false !*/ void set_as_evidence ( ); /*! ensures - #is_evidence() == true !*/ void swap ( bayes_node& item ); /*! ensures - swaps *this and item !*/ }; inline void swap ( bayes_node& a, bayes_node& b ) { a.swap(b); } /*! provides a global swap !*/ void serialize ( const bayes_node& item, std::ostream& out ); /*! provides serialization support !*/ void deserialize ( bayes_node& item, std::istream& in ); /*! provides deserialization support !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- /* The following group of functions are convenience functions for manipulating bayes_node objects while they are inside a directed_graph. These functions also have additional requires clauses that, in debug mode, will protect you from attempts to manipulate a bayesian network in an inappropriate way. */ namespace bayes_node_utils { template < typename T > void set_node_value ( T& bn, unsigned long n, unsigned long val ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() - val < node_num_values(bn, n) ensures - #bn.node(n).data.value() == val !*/ // ------------------------------------------------------------------------------------ template < typename T > unsigned long node_value ( const T& bn, unsigned long n ); /*!
requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - returns bn.node(n).data.value() !*/ // ------------------------------------------------------------------------------------ template < typename T > bool node_is_evidence ( const T& bn, unsigned long n ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - returns bn.node(n).data.is_evidence() !*/ // ------------------------------------------------------------------------------------ template < typename T > void set_node_as_evidence ( T& bn, unsigned long n ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - executes: bn.node(n).data.set_as_evidence() !*/ // ------------------------------------------------------------------------------------ template < typename T > void set_node_as_nonevidence ( T& bn, unsigned long n ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - executes: bn.node(n).data.set_as_nonevidence() !*/ // ------------------------------------------------------------------------------------ template < typename T > void set_node_num_values ( T& bn, unsigned long n, unsigned long num ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - #bn.node(n).data.table().num_values() == num (i.e. sets the number of different values this node can take) !*/ // ------------------------------------------------------------------------------------ template < typename T > unsigned long node_num_values ( const T& bn, unsigned long n ); /*!
requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - returns bn.node(n).data.table().num_values() (i.e. returns the number of different values this node can take) !*/ // ------------------------------------------------------------------------------------ template < typename T > const double node_probability ( const T& bn, unsigned long n, unsigned long value, const assignment& parents ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() - value < node_num_values(bn,n) - parents.size() == bn.node(n).number_of_parents() - if (parents.has_index(x)) then - bn.has_edge(x, n) - parents[x] < node_num_values(bn,x) ensures - returns bn.node(n).data.table().probability(value, parents) (i.e. returns the probability of node n having the given value when its parents have the given assignment) !*/ // ------------------------------------------------------------------------------------ template < typename T > void set_node_probability ( T& bn, unsigned long n, unsigned long value, const assignment& parents, double p ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() - value < node_num_values(bn,n) - 0 <= p <= 1 - parents.size() == bn.node(n).number_of_parents() - if (parents.has_index(x)) then - bn.has_edge(x, n) - parents[x] < node_num_values(bn,x) ensures - #bn.node(n).data.table().probability(value, parents) == p (i.e. sets the probability of node n having the given value when its parents have the given assignment to the probability p. Since this modifies bn, bn is taken by non-const reference.) !*/ // ------------------------------------------------------------------------------------ template const assignment node_first_parent_assignment ( const T& bn, unsigned long n ); /*!
requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - returns an assignment A such that: - A.size() == bn.node(n).number_of_parents() - if (P is a parent of bn.node(n)) then - A.has_index(P) - A[P] == 0 - I.e. this function returns an assignment that contains all the parents of the given node. Also, the value of every parent in the assignment is set to zero. !*/ // ------------------------------------------------------------------------------------ template bool node_next_parent_assignment ( const T& bn, unsigned long n, assignment& A ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() - A.size() == bn.node(n).number_of_parents() - if (A.has_index(x)) then - bn.has_edge(x, n) - A[x] < node_num_values(bn,x) ensures - The behavior of this function is defined by the following code: assignment a(node_first_parent_assignment(bn,n)); do { // this loop loops over all possible parent assignments // of the node bn.node(n). Each time through the loop variable a // will be the next assignment. } while (node_next_parent_assignment(bn,n,a)); !*/ // ------------------------------------------------------------------------------------ template bool node_cpt_filled_out ( const T& bn, unsigned long n ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node - n < bn.number_of_nodes() ensures - if (the conditional_probability_table bn.node(n).data.table() is fully filled out for this node) then - returns true - This means that each parent assignment for the given node along with all possible values of this node shows up in the table.
- It also means that all the probabilities conditioned on the same parent assignment sum to 1.0 - else - returns false !*/ } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class bayesian_network_gibbs_sampler : noncopyable { /*! INITIAL VALUE This object has no state WHAT THIS OBJECT REPRESENTS This object performs Markov Chain Monte Carlo sampling of a bayesian network using the Gibbs sampling technique. Note that this object is limited to only bayesian networks that don't contain deterministic nodes. That is, incorrect results may be computed if this object is used when the bayesian network contains any nodes that have a probability of 1 in their conditional probability tables for any event. So don't use this object for networks with deterministic nodes. !*/ public: bayesian_network_gibbs_sampler ( ); /*! ensures - this object is properly initialized !*/ template < typename T > void sample_graph ( T& bn ); /*! requires - T is an implementation of directed_graph/directed_graph_kernel_abstract.h - T::type == bayes_node ensures - randomly samples (via the Gibbs sampling technique) all the nodes in the network and updates their values with the newly sampled values !*/ }; // ---------------------------------------------------------------------------------------- class bayesian_network_join_tree : noncopyable { /*! WHAT THIS OBJECT REPRESENTS This object represents an implementation of the join tree algorithm for inference in bayesian networks. It doesn't have any mutable state. To use it you just give it a directed_graph that contains a bayesian network and a graph object that contains that network's corresponding join tree. Then you may query this object to determine the probabilities of any variables in the original bayesian network.
!*/ public: template < typename bn_type, typename join_tree_type > bayesian_network_join_tree ( const bn_type& bn, const join_tree_type& join_tree ); /*! requires - bn_type is an implementation of directed_graph/directed_graph_kernel_abstract.h - bn_type::type == bayes_node - join_tree_type is an implementation of graph/graph_kernel_abstract.h - join_tree_type::type is an implementation of set/set_compare_abstract.h and this set type contains unsigned long objects. - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and this set type contains unsigned long objects. - is_join_tree(bn, join_tree) == true - bn == a valid bayesian network with all its conditional probability tables filled out - for all valid n: - node_cpt_filled_out(bn,n) == true - graph_contains_length_one_cycle(bn) == false - graph_is_connected(bn) == true - bn.number_of_nodes() > 0 ensures - this object is properly initialized !*/ unsigned long number_of_nodes ( ) const; /*! ensures - returns the number of nodes in the bayesian network that this object was instantiated from. !*/ const matrix probability( unsigned long idx ) const; /*! requires - idx < number_of_nodes() ensures - returns the probability distribution for the node with index idx that was in the bayesian network that *this was instantiated from. Let D represent this distribution, then: - D.nc() == the number of values the node idx ranges over - D.nr() == 1 - D(i) == the probability of node idx taking on the value i !*/ void swap ( bayesian_network_join_tree& item ); /*! ensures - swaps *this with item !*/ }; inline void swap ( bayesian_network_join_tree& a, bayesian_network_join_tree& b ) { a.swap(b); } /*!
provides a global swap !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_BAYES_UTILs_ABSTRACT_ ================================================ FILE: dlib/bayes_utils.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BAYES_UTILs_H_ #define DLIB_BAYES_UTILs_H_ #include "bayes_utils/bayes_utils.h" #endif // DLIB_BAYES_UTILs_H_ ================================================ FILE: dlib/bigint/bigint_kernel_1.cpp ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIGINT_KERNEL_1_CPp_ #define DLIB_BIGINT_KERNEL_1_CPp_ #include "bigint_kernel_1.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member/friend function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // Default constructor: allocates a new data_record whose size argument is slack (25). bigint_kernel_1:: bigint_kernel_1 ( ) : slack(25), data(new data_record(slack)) {} // ---------------------------------------------------------------------------------------- // Constructs from a 32 bit value stored as two 16 bit digits, least significant digit first. digits_used is only raised to 2 when the high digit is nonzero (presumably data_record starts with digits_used == 1 -- confirm in bigint_kernel_1.h). 
bigint_kernel_1:: bigint_kernel_1 ( uint32 value ) : slack(25), data(new data_record(slack)) { *(data->number) = static_cast(value&0xFFFF); *(data->number+1) = static_cast((value>>16)&0xFFFF); if (*(data->number+1) != 0) data->digits_used = 2; } // ---------------------------------------------------------------------------------------- // Copy constructor: shares item's data_record instead of copying the digits, and bumps its reference count. bigint_kernel_1:: bigint_kernel_1 ( const bigint_kernel_1& item ) : slack(25), data(item.data) { data->references += 1; } //
---------------------------------------------------------------------------------------- // Destructor: the last owner of the shared data_record deletes it; otherwise just drop our reference. bigint_kernel_1:: ~bigint_kernel_1 ( ) { if (data->references == 1) { delete data; } else { data->references -= 1; } } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator+ ( const bigint_kernel_1& rhs ) const { data_record* temp = new data_record ( std::max(rhs.data->digits_used,data->digits_used) + slack ); long_add(data,rhs.data,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator+= ( const bigint_kernel_1& rhs ) { // if there are other references to our data if (data->references != 1) { data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); data->references -= 1; long_add(data,rhs.data,temp); data = temp; } // if data is not big enough for the result else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) { data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); long_add(data,rhs.data,temp); delete data; data = temp; } // there is enough size and no references else { long_add(data,rhs.data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator- ( const bigint_kernel_1& rhs ) const { data_record* temp = new data_record ( data->digits_used + slack ); long_sub(data,rhs.data,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator-= ( const bigint_kernel_1& rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1;
long_sub(data,rhs.data,temp); data = temp; } else { long_sub(data,rhs.data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator* ( const bigint_kernel_1& rhs ) const { data_record* temp = new data_record ( data->digits_used + rhs.data->digits_used + slack ); long_mul(data,rhs.data,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator*= ( const bigint_kernel_1& rhs ) { // create a data_record to store the result of the multiplication in data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); long_mul(data,rhs.data,temp); // if there are other references to data if (data->references != 1) { data->references -= 1; } else { delete data; } data = temp; return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator/ ( const bigint_kernel_1& rhs ) const { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...) { delete temp; throw; } long_div(data,rhs.data,temp,remainder); delete remainder; return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator/= ( const bigint_kernel_1& rhs ) { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...)
{ delete temp; throw; } long_div(data,rhs.data,temp,remainder); // check if there are other references to data if (data->references != 1) { data->references -= 1; } // if there are no references to data then it must be deleted else { delete data; } data = temp; delete remainder; return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator% ( const bigint_kernel_1& rhs ) const { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...) { delete temp; throw; } long_div(data,rhs.data,temp,remainder); delete temp; return bigint_kernel_1(remainder,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator%= ( const bigint_kernel_1& rhs ) { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...)
{ delete temp; throw; } long_div(data,rhs.data,temp,remainder); // check if there are other references to data if (data->references != 1) { data->references -= 1; } // if there are no references to data then it must be deleted else { delete data; } data = remainder; delete temp; return *this; } // ---------------------------------------------------------------------------------------- bool bigint_kernel_1:: operator < ( const bigint_kernel_1& rhs ) const { return is_less_than(data,rhs.data); } // ---------------------------------------------------------------------------------------- bool bigint_kernel_1:: operator == ( const bigint_kernel_1& rhs ) const { return is_equal_to(data,rhs.data); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator= ( const bigint_kernel_1& rhs ) { if (this == &rhs) return *this; // if we have the only reference to our data then delete it if (data->references == 1) { delete data; data = rhs.data; data->references += 1; } else { data->references -= 1; data = rhs.data; data->references += 1; } return *this; } // ---------------------------------------------------------------------------------------- // Writes rhs in decimal. A copy of rhs's digits is repeatedly divided by 10000 and each remainder is written as four characters, right to left, into a heap buffer; afterwards at most three leading '0' characters are skipped. The text is written through a local ostream built on out_'s streambuf, so out_'s own format flags are not consulted. std::ostream& operator<< ( std::ostream& out_, const bigint_kernel_1& rhs ) { std::ostream out(out_.rdbuf()); typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record(*rhs.data,0); // get a char array big enough to hold the number in ascii format char* str; try { str = new char[(rhs.data->digits_used)*5+10]; } catch (...)
{ delete temp; throw; } char* str_start = str; str += (rhs.data->digits_used)*5+9; *str = 0; --str; uint16 remainder; rhs.short_div(temp,10000,temp,remainder); // pull the digits out of remainder char a = remainder % 10 + '0'; remainder /= 10; char b = remainder % 10 + '0'; remainder /= 10; char c = remainder % 10 + '0'; remainder /= 10; char d = remainder % 10 + '0'; remainder /= 10; *str = a; --str; *str = b; --str; *str = c; --str; *str = d; --str; // keep looping until temp represents zero while (temp->digits_used != 1 || *(temp->number) != 0) { rhs.short_div(temp,10000,temp,remainder); // pull the digits out of remainder char a = remainder % 10 + '0'; remainder /= 10; char b = remainder % 10 + '0'; remainder /= 10; char c = remainder % 10 + '0'; remainder /= 10; char d = remainder % 10 + '0'; remainder /= 10; *str = a; --str; *str = b; --str; *str = c; --str; *str = d; --str; } // throw away any extra leading zeros ++str; if (*str == '0') ++str; if (*str == '0') ++str; if (*str == '0') ++str; out << str; delete [] str_start; delete temp; return out_; } // ---------------------------------------------------------------------------------------- // Reads a non-negative decimal number into rhs. Leading spaces, tabs and newlines are skipped; if the next character is not a digit, in_ is set to failbit and rhs is left unchanged. Digits are consumed in groups of up to four through a local istream built on in_'s streambuf, so only the failure case above is reported through in_'s state. 
std::istream& operator>> ( std::istream& in_, bigint_kernel_1& rhs ) { std::istream in(in_.rdbuf()); // ignore any leading whitespaces while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') { in.get(); } // if the first digit is not an integer then this is an error if ( !(in.peek() >= '0' && in.peek() <= '9')) { in_.clear(std::ios::failbit); return in_; } int num_read; bigint_kernel_1 temp; do { // try to get 4 chars from in (num_read ends up as 10 raised to the number of digits read in this pass) num_read = 1; char a = 0; char b = 0; char c = 0; char d = 0; if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; a = in.get(); } if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; b = in.get(); } if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; c = in.get(); } if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; d = in.get(); } // merge the four digits
into a uint16 uint16 num = 0; if (a != 0) { num = a - '0'; } if (b != 0) { num *= 10; num += b - '0'; } if (c != 0) { num *= 10; num += c - '0'; } if (d != 0) { num *= 10; num += d - '0'; } if (num_read != 1) { // shift the digits in temp left by the number of new digits we just read temp *= num_read; // add in new digits temp += num; } } while (num_read == 10000); rhs = temp; return in_; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator+ ( uint16 lhs, const bigint_kernel_1& rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record (rhs.data->digits_used+rhs.slack); rhs.short_add(rhs.data,lhs,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator+ ( const bigint_kernel_1& lhs, uint16 rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); lhs.short_add(lhs.data,rhs,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator+= ( uint16 rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_add(data,rhs,temp); data = temp; } // or if we need to enlarge data then do so else if (data->digits_used == data->size) { data_record* temp = new data_record(data->digits_used+slack); short_add(data,rhs,temp); delete data; data = temp; } // or if there is plenty of space and no references else { short_add(data,rhs,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator- ( uint16 lhs, const bigint_kernel_1& rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp =
new bigint::data_record(rhs.slack); *(temp->number) = lhs - *(rhs.data->number); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator- ( const bigint_kernel_1& lhs, uint16 rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); lhs.short_sub(lhs.data,rhs,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator-= ( uint16 rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_sub(data,rhs,temp); data = temp; } else { short_sub(data,rhs,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator* ( uint16 lhs, const bigint_kernel_1& rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record (rhs.data->digits_used+rhs.slack); rhs.short_mul(rhs.data,lhs,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator* ( const bigint_kernel_1& lhs, uint16 rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); lhs.short_mul(lhs.data,rhs,temp); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator*= ( uint16 rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_mul(data,rhs,temp); data = temp; } // or if we need to enlarge data else if 
(data->digits_used == data->size) { data_record* temp = new data_record(data->digits_used+slack); short_mul(data,rhs,temp); delete data; data = temp; } else { short_mul(data,rhs,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator/ ( uint16 lhs, const bigint_kernel_1& rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record(rhs.slack); // if rhs might not be bigger than lhs if (rhs.data->digits_used == 1) { *(temp->number) = lhs/ *(rhs.data->number); } return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator/ ( const bigint_kernel_1& lhs, uint16 rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); uint16 remainder; lhs.short_div(lhs.data,rhs,temp,remainder); return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator/= ( uint16 rhs ) { uint16 remainder; // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_div(data,rhs,temp,remainder); data = temp; } else { short_div(data,rhs,data,remainder); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator% ( uint16 lhs, const bigint_kernel_1& rhs ) { typedef bigint_kernel_1 bigint; // temp is zero by default bigint::data_record* temp = new bigint::data_record(rhs.slack); if (rhs.data->digits_used == 1) { // if rhs is just an uint16 inside then perform the modulus *(temp->number) = lhs % *(rhs.data->number); } else { // if rhs is bigger than lhs then the answer is lhs *(temp->number) = lhs; } return bigint_kernel_1(temp,0); 
} // ---------------------------------------------------------------------------------------- const bigint_kernel_1 operator% ( const bigint_kernel_1& lhs, uint16 rhs ) { typedef bigint_kernel_1 bigint; bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); uint16 remainder; lhs.short_div(lhs.data,rhs,temp,remainder); temp->digits_used = 1; *(temp->number) = remainder; return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator%= ( uint16 rhs ) { uint16 remainder; // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_div(data,rhs,temp,remainder); data = temp; } else { short_div(data,rhs,data,remainder); } data->digits_used = 1; *(data->number) = remainder; return *this; } // ---------------------------------------------------------------------------------------- bool operator < ( uint16 lhs, const bigint_kernel_1& rhs ) { return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); } // ---------------------------------------------------------------------------------------- bool operator < ( const bigint_kernel_1& lhs, uint16 rhs ) { return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); } // ---------------------------------------------------------------------------------------- bool operator == ( const bigint_kernel_1& lhs, uint16 rhs ) { return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); } // ---------------------------------------------------------------------------------------- bool operator == ( uint16 lhs, const bigint_kernel_1& rhs ) { return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator= ( uint16 rhs ) { // check if there are 
other references to our data if (data->references != 1) { data->references -= 1; try { data = new data_record(slack); } catch (...) { data->references += 1; throw; } } else { data->digits_used = 1; } *(data->number) = rhs; return *this; } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator++ ( ) { // if there are other references to this data then make a copy of it if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; increment(data,temp); data = temp; } // or if we need to enlarge data then do so else if (data->digits_used == data->size) { data_record* temp = new data_record(data->digits_used+slack); increment(data,temp); delete data; data = temp; } else { increment(data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator++ ( int ) { data_record* temp; // this is the copy of temp we will return in the end data_record* temp2 = new data_record(data->digits_used+slack); increment(data,temp2); temp = data; data = temp2; return bigint_kernel_1(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_1& bigint_kernel_1:: operator-- ( ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; decrement(data,temp); data = temp; } else { decrement(data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_1 bigint_kernel_1:: operator-- ( int ) { data_record* temp; // this is the copy of temp we will return in the end data_record* temp2 = new data_record(data->digits_used+slack); decrement(data,temp2); temp = data; data = temp2; return bigint_kernel_1(temp,0); } // 
---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: short_add ( const data_record* data, uint16 value, data_record* result ) const { // put value into the carry part of temp uint32 temp = value; temp <<= 16; const uint16* number = data->number; const uint16* end = number + data->digits_used; // one past the end of number uint16* r = result->number; while (number != end) { // add *number and the current carry temp = *number + (temp>>16); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++number; ++r; } // if there is a final carry if ((temp>>16) != 0) { result->digits_used = data->digits_used + 1; // store the carry in the most significant digit of the result *r = static_cast(temp>>16); } else { result->digits_used = data->digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: short_sub ( const data_record* data, uint16 value, data_record* result ) const { const uint16* number = data->number; const uint16* end = number + data->digits_used - 1; uint16* r = result->number; uint32 temp = *number - value; // put the low word of temp into *data *r = static_cast(temp & 0xFFFF); while (number != end) { ++number; ++r; // subtract the carry from *number temp = *number - (temp>>31); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); } // if we lost a digit in the subtraction if (*r == 0) { if (data->digits_used == 1) result->digits_used = 1; else result->digits_used = data->digits_used - 1; } else { result->digits_used = data->digits_used; } } // 
---------------------------------------------------------------------------------------- void bigint_kernel_1:: short_mul ( const data_record* data, uint16 value, data_record* result ) const { uint32 temp = 0; const uint16* number = data->number; uint16* r = result->number; const uint16* end = r + data->digits_used; while ( r != end) { // multiply *data and value and add in the carry temp = *number*(uint32)value + (temp>>16); // put the low word of temp into *data *r = static_cast(temp & 0xFFFF); ++number; ++r; } // if there is a final carry if ((temp>>16) != 0) { result->digits_used = data->digits_used + 1; // put the final carry into the most significant digit of the result *r = static_cast(temp>>16); } else { result->digits_used = data->digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: short_div ( const data_record* data, uint16 value, data_record* result, uint16& rem ) const { uint16 remainder = 0; uint32 temp; const uint16* number = data->number + data->digits_used - 1; const uint16* end = number - data->digits_used; uint16* r = result->number + data->digits_used - 1; // if we are losing a digit in this division if (*number < value) { if (data->digits_used == 1) result->digits_used = 1; else result->digits_used = data->digits_used - 1; } else { result->digits_used = data->digits_used; } // perform the actual division while (number != end) { temp = *number + (((uint32)remainder)<<16); *r = static_cast(temp/value); remainder = static_cast(temp%value); --number; --r; } rem = remainder; } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: long_add ( const data_record* lhs, const data_record* rhs, data_record* result ) const { // put value into the carry part of temp uint32 temp=0; uint16* min_num; // the number with the least digits used uint16* max_num; // the number with the most digits used uint16* min_end; // one 
past the end of min_num uint16* max_end; // one past the end of max_num uint16* r = result->number; uint32 max_digits_used; if (lhs->digits_used < rhs->digits_used) { max_digits_used = rhs->digits_used; min_num = lhs->number; max_num = rhs->number; min_end = min_num + lhs->digits_used; max_end = max_num + rhs->digits_used; } else { max_digits_used = lhs->digits_used; min_num = rhs->number; max_num = lhs->number; min_end = min_num + rhs->digits_used; max_end = max_num + lhs->digits_used; } while (min_num != min_end) { // add *min_num, *max_num and the current carry temp = *min_num + *max_num + (temp>>16); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++min_num; ++max_num; ++r; } while (max_num != max_end) { // add *max_num and the current carry temp = *max_num + (temp>>16); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++max_num; ++r; } // check if there was a final carry if ((temp>>16) != 0) { result->digits_used = max_digits_used + 1; // put the carry into the most significant digit in the result *r = static_cast(temp>>16); } else { result->digits_used = max_digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: long_sub ( const data_record* lhs, const data_record* rhs, data_record* result ) const { const uint16* number1 = lhs->number; const uint16* number2 = rhs->number; const uint16* end = number2 + rhs->digits_used; uint16* r = result->number; uint32 temp =0; while (number2 != end) { // subtract *number2 from *number1 and then subtract any carry temp = *number1 - *number2 - (temp>>31); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++number1; ++number2; ++r; } end = lhs->number + lhs->digits_used; while (number1 != end) { // subtract the carry from *number1 temp = *number1 - (temp>>31); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++number1; ++r; } result->digits_used = lhs->digits_used; // 
adjust the number of digits used appropriately --r; while (*r == 0 && result->digits_used > 1) { --r; --result->digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: long_div ( const data_record* lhs, const data_record* rhs, data_record* result, data_record* remainder ) const { // zero result result->digits_used = 1; *(result->number) = 0; uint16* a; uint16* b; uint16* end; // copy lhs into remainder remainder->digits_used = lhs->digits_used; a = remainder->number; end = a + remainder->digits_used; b = lhs->number; while (a != end) { *a = *b; ++a; ++b; } // if rhs is bigger than lhs then result == 0 and remainder == lhs // so then we can quit right now if (is_less_than(lhs,rhs)) { return; } // make a temporary number data_record temp(lhs->digits_used + slack); // shift rhs left until it is one shift away from being larger than lhs and // put the number of left shifts necessary into shifts uint32 shifts; shifts = (lhs->digits_used - rhs->digits_used) * 16; shift_left(rhs,&temp,shifts); // while (lhs > temp) while (is_less_than(&temp,lhs)) { shift_left(&temp,&temp,1); ++shifts; } // make sure lhs isn't smaller than temp while (is_less_than(lhs,&temp)) { shift_right(&temp,&temp); --shifts; } // we want to execute the loop shifts +1 times ++shifts; while (shifts != 0) { shift_left(result,result,1); // if (temp <= remainder) if (!is_less_than(remainder,&temp)) { long_sub(remainder,&temp,remainder); // increment result uint16* r = result->number; uint16* end = r + result->digits_used; while (true) { ++(*r); // if there was no carry then we are done if (*r != 0) break; ++r; // if we hit the end of r and there is still a carry then // the next digit of r is 1 and there is one more digit used if (r == end) { *r = 1; ++(result->digits_used); break; } } } shift_right(&temp,&temp); --shifts; } } // ---------------------------------------------------------------------------------------- void 
bigint_kernel_1:: long_mul ( const data_record* lhs, const data_record* rhs, data_record* result ) const { // make result be zero result->digits_used = 1; *(result->number) = 0; const data_record* aa; const data_record* bb; if (lhs->digits_used < rhs->digits_used) { // make copies of lhs and rhs and give them an appropriate amount of // extra memory so there won't be any overflows aa = lhs; bb = rhs; } else { // make copies of lhs and rhs and give them an appropriate amount of // extra memory so there won't be any overflows aa = rhs; bb = lhs; } // this is where we actually copy lhs and rhs data_record b(*bb,aa->digits_used+slack); // the larger(approximately) of lhs and rhs uint32 shift_value = 0; uint16* anum = aa->number; uint16* end = anum + aa->digits_used; while (anum != end ) { uint16 bit = 0x0001; for (int i = 0; i < 16; ++i) { // if the specified bit of a is 1 if ((*anum & bit) != 0) { shift_left(&b,&b,shift_value); shift_value = 0; long_add(&b,result,result); } ++shift_value; bit <<= 1; } ++anum; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: shift_left ( const data_record* data, data_record* result, uint32 shift_amount ) const { uint32 offset = shift_amount/16; shift_amount &= 0xf; // same as shift_amount %= 16; uint16* r = result->number + data->digits_used + offset; // result uint16* end = data->number; uint16* s = end + data->digits_used; // source const uint32 temp = 16 - shift_amount; *r = (*(--s) >> temp); // set the number of digits used in the result // if the upper bits from *s were zero then don't count this first word if (*r == 0) { result->digits_used = data->digits_used + offset; } else { result->digits_used = data->digits_used + offset + 1; } --r; while (s != end) { *r = ((*s << shift_amount) | ( *(s-1) >> temp)); --r; --s; } *r = *s << shift_amount; // now zero the rest of the result end = result->number; while (r != end) *(--r) = 0; } // 
---------------------------------------------------------------------------------------- void bigint_kernel_1:: shift_right ( const data_record* data, data_record* result ) const { uint16* r = result->number; // result uint16* s = data->number; // source uint16* end = s + data->digits_used - 1; while (s != end) { *r = (*s >> 1) | (*(s+1) << 15); ++r; ++s; } *r = *s >> 1; // calculate the new number for digits_used if (*r == 0) { if (data->digits_used != 1) result->digits_used = data->digits_used - 1; else result->digits_used = 1; } else { result->digits_used = data->digits_used; } } // ---------------------------------------------------------------------------------------- bool bigint_kernel_1:: is_less_than ( const data_record* lhs, const data_record* rhs ) const { uint32 lhs_digits_used = lhs->digits_used; uint32 rhs_digits_used = rhs->digits_used; // if lhs is definitely less than rhs if (lhs_digits_used < rhs_digits_used ) return true; // if lhs is definitely greater than rhs else if (lhs_digits_used > rhs_digits_used) return false; else { uint16* end = lhs->number; uint16* l = end + lhs_digits_used; uint16* r = rhs->number + rhs_digits_used; while (l != end) { --l; --r; if (*l < *r) return true; else if (*l > *r) return false; } // at this point we know that they are equal return false; } } // ---------------------------------------------------------------------------------------- bool bigint_kernel_1:: is_equal_to ( const data_record* lhs, const data_record* rhs ) const { // if lhs and rhs are definitely not equal if (lhs->digits_used != rhs->digits_used ) { return false; } else { uint16* l = lhs->number; uint16* r = rhs->number; uint16* end = l + lhs->digits_used; while (l != end) { if (*l != *r) return false; ++l; ++r; } // at this point we know that they are equal return true; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: increment ( const data_record* source, data_record* dest ) 
const { uint16* s = source->number; uint16* d = dest->number; uint16* end = s + source->digits_used; while (true) { *d = *s + 1; // if there was no carry then break out of the loop if (*d != 0) { dest->digits_used = source->digits_used; // copy the rest of the digits over to d ++d; ++s; while (s != end) { *d = *s; ++d; ++s; } break; } ++s; // if we have hit the end of s and there was a carry up to this point // then just make the next digit 1 and add one to the digits used if (s == end) { ++d; dest->digits_used = source->digits_used + 1; *d = 1; break; } ++d; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_1:: decrement ( const data_record* source, data_record* dest ) const { uint16* s = source->number; uint16* d = dest->number; uint16* end = s + source->digits_used; while (true) { *d = *s - 1; // if there was no carry then break out of the loop if (*d != 0xFFFF) { // if we lost a digit in the subtraction if (*d == 0 && s+1 == end) { if (source->digits_used == 1) dest->digits_used = 1; else dest->digits_used = source->digits_used - 1; } else { dest->digits_used = source->digits_used; } break; } else { ++d; ++s; } } // copy the rest of the digits over to d ++d; ++s; while (s != end) { *d = *s; ++d; ++s; } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIGINT_KERNEL_1_CPp_ ================================================ FILE: dlib/bigint/bigint_kernel_1.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIGINT_KERNEl_1_ #define DLIB_BIGINT_KERNEl_1_ #include "bigint_kernel_abstract.h" #include "../algs.h" #include "../serialize.h" #include "../uintn.h" #include namespace dlib { class bigint_kernel_1 { /*! 
INITIAL VALUE slack == 25 data->number[0] == 0 data->size == slack data->references == 1 data->digits_used == 1 CONVENTION slack == the number of extra digits placed into the number when it is created. the slack value should never be less than 1 data->number == pointer to an array of data->size uint16s. data represents a string of base 65535 numbers with data[0] being the least significant bit and data[data->digits_used-1] being the most significant NOTE: In the comments I will consider a word to be a 16 bit value data->digits_used == the number of significant digits in the number. data->digits_used tells us the number of used elements in the data->number array so everything beyond data->number[data->digits_used-1] is undefined data->references == the number of bigint_kernel_1 objects which refer to this data_record !*/ struct data_record { explicit data_record( uint32 size_ ) : size(size_), number(new uint16[size_]), references(1), digits_used(1) {*number = 0;} /*! ensures - initializes *this to represent zero !*/ data_record( const data_record& item, uint32 additional_size ) : size(item.digits_used + additional_size), number(new uint16[size]), references(1), digits_used(item.digits_used) { uint16* source = item.number; uint16* dest = number; uint16* end = source + digits_used; while (source != end) { *dest = *source; ++dest; ++source; } } /*! ensures - *this is a copy of item except with size == item.digits_used + additional_size !*/ ~data_record( ) { delete [] number; } const uint32 size; uint16* number; uint32 references; uint32 digits_used; private: // no copy constructor data_record ( data_record&); }; // note that the second parameter is just there // to resolve the ambiguity between this constructor and // bigint_kernel_1(uint32) explicit bigint_kernel_1 ( data_record* data_, int ): slack(25),data(data_) {} /*! 
ensures - *this is initialized with data_ as its data member !*/ public: bigint_kernel_1 ( ); bigint_kernel_1 ( uint32 value ); bigint_kernel_1 ( const bigint_kernel_1& item ); virtual ~bigint_kernel_1 ( ); const bigint_kernel_1 operator+ ( const bigint_kernel_1& rhs ) const; bigint_kernel_1& operator+= ( const bigint_kernel_1& rhs ); const bigint_kernel_1 operator- ( const bigint_kernel_1& rhs ) const; bigint_kernel_1& operator-= ( const bigint_kernel_1& rhs ); const bigint_kernel_1 operator* ( const bigint_kernel_1& rhs ) const; bigint_kernel_1& operator*= ( const bigint_kernel_1& rhs ); const bigint_kernel_1 operator/ ( const bigint_kernel_1& rhs ) const; bigint_kernel_1& operator/= ( const bigint_kernel_1& rhs ); const bigint_kernel_1 operator% ( const bigint_kernel_1& rhs ) const; bigint_kernel_1& operator%= ( const bigint_kernel_1& rhs ); bool operator < ( const bigint_kernel_1& rhs ) const; bool operator == ( const bigint_kernel_1& rhs ) const; bigint_kernel_1& operator= ( const bigint_kernel_1& rhs ); friend std::ostream& operator<< ( std::ostream& out, const bigint_kernel_1& rhs ); friend std::istream& operator>> ( std::istream& in, bigint_kernel_1& rhs ); bigint_kernel_1& operator++ ( ); const bigint_kernel_1 operator++ ( int ); bigint_kernel_1& operator-- ( ); const bigint_kernel_1 operator-- ( int ); friend const bigint_kernel_1 operator+ ( uint16 lhs, const bigint_kernel_1& rhs ); friend const bigint_kernel_1 operator+ ( const bigint_kernel_1& lhs, uint16 rhs ); bigint_kernel_1& operator+= ( uint16 rhs ); friend const bigint_kernel_1 operator- ( uint16 lhs, const bigint_kernel_1& rhs ); friend const bigint_kernel_1 operator- ( const bigint_kernel_1& lhs, uint16 rhs ); bigint_kernel_1& operator-= ( uint16 rhs ); friend const bigint_kernel_1 operator* ( uint16 lhs, const bigint_kernel_1& rhs ); friend const bigint_kernel_1 operator* ( const bigint_kernel_1& lhs, uint16 rhs ); bigint_kernel_1& operator*= ( uint16 rhs ); friend const bigint_kernel_1 
operator/ ( uint16 lhs, const bigint_kernel_1& rhs ); friend const bigint_kernel_1 operator/ ( const bigint_kernel_1& lhs, uint16 rhs ); bigint_kernel_1& operator/= ( uint16 rhs ); friend const bigint_kernel_1 operator% ( uint16 lhs, const bigint_kernel_1& rhs ); friend const bigint_kernel_1 operator% ( const bigint_kernel_1& lhs, uint16 rhs ); bigint_kernel_1& operator%= ( uint16 rhs ); friend bool operator < ( uint16 lhs, const bigint_kernel_1& rhs ); friend bool operator < ( const bigint_kernel_1& lhs, uint16 rhs ); friend bool operator == ( const bigint_kernel_1& lhs, uint16 rhs ); friend bool operator == ( uint16 lhs, const bigint_kernel_1& rhs ); bigint_kernel_1& operator= ( uint16 rhs ); void swap ( bigint_kernel_1& item ) { data_record* temp = data; data = item.data; item.data = temp; } private: void long_add ( const data_record* lhs, const data_record* rhs, data_record* result ) const; /*! requires - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 ensures - result == lhs + rhs !*/ void long_sub ( const data_record* lhs, const data_record* rhs, data_record* result ) const; /*! requires - lhs >= rhs - result->size >= lhs->digits_used ensures - result == lhs - rhs !*/ void long_div ( const data_record* lhs, const data_record* rhs, data_record* result, data_record* remainder ) const; /*! requires - rhs != 0 - result->size >= lhs->digits_used - remainder->size >= lhs->digits_used - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) ensures - result == lhs / rhs - remainder == lhs % rhs !*/ void long_mul ( const data_record* lhs, const data_record* rhs, data_record* result ) const; /*! requires - result->size >= lhs->digits_used + rhs->digits_used - result != lhs - result != rhs ensures - result == lhs * rhs !*/ void short_add ( const data_record* data, uint16 value, data_record* result ) const; /*! 
requires - result->size >= data->size + 1 ensures - result == data + value !*/ void short_sub ( const data_record* data, uint16 value, data_record* result ) const; /*! requires - data >= value - result->size >= data->digits_used ensures - result == data - value !*/ void short_mul ( const data_record* data, uint16 value, data_record* result ) const; /*! requires - result->size >= data->digits_used + 1 ensures - result == data * value !*/ void short_div ( const data_record* data, uint16 value, data_record* result, uint16& remainder ) const; /*! requires - value != 0 - result->size >= data->digits_used ensures - result = data*value - remainder = data%value !*/ void shift_left ( const data_record* data, data_record* result, uint32 shift_amount ) const; /*! requires - result->size >= data->digits_used + shift_amount/8 + 1 ensures - result == data << shift_amount !*/ void shift_right ( const data_record* data, data_record* result ) const; /*! requires - result->size >= data->digits_used ensures - result == data >> 1 !*/ bool is_less_than ( const data_record* lhs, const data_record* rhs ) const; /*! ensures - returns true if lhs < rhs - returns false otherwise !*/ bool is_equal_to ( const data_record* lhs, const data_record* rhs ) const; /*! ensures - returns true if lhs == rhs - returns false otherwise !*/ void increment ( const data_record* source, data_record* dest ) const; /*! requires - dest->size >= source->digits_used + 1 ensures - dest = source + 1 !*/ void decrement ( const data_record* source, data_record* dest ) const; /*! 
requires source != 0 ensuers dest = source - 1 !*/ // member data const uint32 slack; data_record* data; }; inline void swap ( bigint_kernel_1& a, bigint_kernel_1& b ) { a.swap(b); } inline void serialize ( const bigint_kernel_1& item, std::ostream& out ) { std::ios::fmtflags oldflags = out.flags(); out << item << ' '; out.flags(oldflags); if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); } inline void deserialize ( bigint_kernel_1& item, std::istream& in ) { std::ios::fmtflags oldflags = in.flags(); in >> item; in.flags(oldflags); if (in.get() != ' ') { item = 0; throw serialization_error("Error deserializing object of type bigint_kernel_c"); } } inline bool operator> (const bigint_kernel_1& a, const bigint_kernel_1& b) { return b < a; } inline bool operator!= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a == b); } inline bool operator<= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(b < a); } inline bool operator>= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a < b); } } #ifdef NO_MAKEFILE #include "bigint_kernel_1.cpp" #endif #endif // DLIB_BIGINT_KERNEl_1_ ================================================ FILE: dlib/bigint/bigint_kernel_2.cpp ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BIGINT_KERNEL_2_CPp_ #define DLIB_BIGINT_KERNEL_2_CPp_ #include "bigint_kernel_2.h" #include #include namespace dlib { // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member/friend function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- bigint_kernel_2:: bigint_kernel_2 ( ) : slack(25), data(new data_record(slack)) {} // ---------------------------------------------------------------------------------------- bigint_kernel_2:: bigint_kernel_2 ( uint32 value ) : slack(25), data(new data_record(slack)) { *(data->number) = static_cast(value&0xFFFF); *(data->number+1) = static_cast((value>>16)&0xFFFF); if (*(data->number+1) != 0) data->digits_used = 2; } // ---------------------------------------------------------------------------------------- bigint_kernel_2:: bigint_kernel_2 ( const bigint_kernel_2& item ) : slack(25), data(item.data) { data->references += 1; } // ---------------------------------------------------------------------------------------- bigint_kernel_2:: ~bigint_kernel_2 ( ) { if (data->references == 1) { delete data; } else { data->references -= 1; } } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator+ ( const bigint_kernel_2& rhs ) const { data_record* temp = new data_record ( std::max(rhs.data->digits_used,data->digits_used) + slack ); long_add(data,rhs.data,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator+= ( const bigint_kernel_2& rhs ) { // if there are other references to our data if (data->references != 1) { data_record* temp = new 
data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); data->references -= 1; long_add(data,rhs.data,temp); data = temp; } // if data is not big enough for the result else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) { data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); long_add(data,rhs.data,temp); delete data; data = temp; } // there is enough size and no references else { long_add(data,rhs.data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator- ( const bigint_kernel_2& rhs ) const { data_record* temp = new data_record ( data->digits_used + slack ); long_sub(data,rhs.data,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator-= ( const bigint_kernel_2& rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; long_sub(data,rhs.data,temp); data = temp; } else { long_sub(data,rhs.data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator* ( const bigint_kernel_2& rhs ) const { data_record* temp = new data_record ( data->digits_used + rhs.data->digits_used + slack ); long_mul(data,rhs.data,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator*= ( const bigint_kernel_2& rhs ) { // create a data_record to store the result of the multiplication in data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); long_mul(data,rhs.data,temp); // if there are other references to data if 
(data->references != 1) { data->references -= 1; } else { delete data; } data = temp; return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator/ ( const bigint_kernel_2& rhs ) const { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...) { delete temp; throw; } long_div(data,rhs.data,temp,remainder); delete remainder; return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator/= ( const bigint_kernel_2& rhs ) { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...) { delete temp; throw; } long_div(data,rhs.data,temp,remainder); // check if there are other references to data if (data->references != 1) { data->references -= 1; } // if there are no references to data then it must be deleted else { delete data; } data = temp; delete remainder; return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator% ( const bigint_kernel_2& rhs ) const { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...) { delete temp; throw; } long_div(data,rhs.data,temp,remainder); delete temp; return bigint_kernel_2(remainder,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator%= ( const bigint_kernel_2& rhs ) { data_record* temp = new data_record(data->digits_used+slack); data_record* remainder; try { remainder = new data_record(data->digits_used+slack); } catch (...) 
{ delete temp; throw; } long_div(data,rhs.data,temp,remainder); // check if there are other references to data if (data->references != 1) { data->references -= 1; } // if there are no references to data then it must be deleted else { delete data; } data = remainder; delete temp; return *this; } // ---------------------------------------------------------------------------------------- bool bigint_kernel_2:: operator < ( const bigint_kernel_2& rhs ) const { return is_less_than(data,rhs.data); } // ---------------------------------------------------------------------------------------- bool bigint_kernel_2:: operator == ( const bigint_kernel_2& rhs ) const { return is_equal_to(data,rhs.data); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator= ( const bigint_kernel_2& rhs ) { if (this == &rhs) return *this; // if we have the only reference to our data then delete it if (data->references == 1) { delete data; data = rhs.data; data->references += 1; } else { data->references -= 1; data = rhs.data; data->references += 1; } return *this; } // ---------------------------------------------------------------------------------------- std::ostream& operator<< ( std::ostream& out_, const bigint_kernel_2& rhs ) { std::ostream out(out_.rdbuf()); typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record(*rhs.data,0); // get a char array big enough to hold the number in ascii format char* str; try { str = new char[(rhs.data->digits_used)*5+10]; } catch (...) 
{ delete temp; throw; } char* str_start = str; str += (rhs.data->digits_used)*5+9; *str = 0; --str; uint16 remainder; rhs.short_div(temp,10000,temp,remainder); // pull the digits out of remainder char a = remainder % 10 + '0'; remainder /= 10; char b = remainder % 10 + '0'; remainder /= 10; char c = remainder % 10 + '0'; remainder /= 10; char d = remainder % 10 + '0'; remainder /= 10; *str = a; --str; *str = b; --str; *str = c; --str; *str = d; --str; // keep looping until temp represents zero while (temp->digits_used != 1 || *(temp->number) != 0) { rhs.short_div(temp,10000,temp,remainder); // pull the digits out of remainder char a = remainder % 10 + '0'; remainder /= 10; char b = remainder % 10 + '0'; remainder /= 10; char c = remainder % 10 + '0'; remainder /= 10; char d = remainder % 10 + '0'; remainder /= 10; *str = a; --str; *str = b; --str; *str = c; --str; *str = d; --str; } // throw away and extra leading zeros ++str; if (*str == '0') ++str; if (*str == '0') ++str; if (*str == '0') ++str; out << str; delete [] str_start; delete temp; return out_; } // ---------------------------------------------------------------------------------------- std::istream& operator>> ( std::istream& in_, bigint_kernel_2& rhs ) { std::istream in(in_.rdbuf()); // ignore any leading whitespaces while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') { in.get(); } // if the first digit is not an integer then this is an error if ( !(in.peek() >= '0' && in.peek() <= '9')) { in_.clear(std::ios::failbit); return in_; } int num_read; bigint_kernel_2 temp; do { // try to get 4 chars from in num_read = 1; char a = 0; char b = 0; char c = 0; char d = 0; if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; a = in.get(); } if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; b = in.get(); } if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; c = in.get(); } if (in.peek() >= '0' && in.peek() <= '9') { num_read *= 10; d = in.get(); } // merge the for digits 
into an uint16 uint16 num = 0; if (a != 0) { num = a - '0'; } if (b != 0) { num *= 10; num += b - '0'; } if (c != 0) { num *= 10; num += c - '0'; } if (d != 0) { num *= 10; num += d - '0'; } if (num_read != 1) { // shift the digits in temp left by the number of new digits we just read temp *= num_read; // add in new digits temp += num; } } while (num_read == 10000); rhs = temp; return in_; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator+ ( uint16 lhs, const bigint_kernel_2& rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record (rhs.data->digits_used+rhs.slack); rhs.short_add(rhs.data,lhs,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator+ ( const bigint_kernel_2& lhs, uint16 rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); lhs.short_add(lhs.data,rhs,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator+= ( uint16 rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_add(data,rhs,temp); data = temp; } // or if we need to enlarge data then do so else if (data->digits_used == data->size) { data_record* temp = new data_record(data->digits_used+slack); short_add(data,rhs,temp); delete data; data = temp; } // or if there is plenty of space and no references else { short_add(data,rhs,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator- ( uint16 lhs, const bigint_kernel_2& rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = 
new bigint::data_record(rhs.slack); *(temp->number) = lhs - *(rhs.data->number); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator- ( const bigint_kernel_2& lhs, uint16 rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); lhs.short_sub(lhs.data,rhs,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator-= ( uint16 rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_sub(data,rhs,temp); data = temp; } else { short_sub(data,rhs,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator* ( uint16 lhs, const bigint_kernel_2& rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record (rhs.data->digits_used+rhs.slack); rhs.short_mul(rhs.data,lhs,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator* ( const bigint_kernel_2& lhs, uint16 rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); lhs.short_mul(lhs.data,rhs,temp); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator*= ( uint16 rhs ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_mul(data,rhs,temp); data = temp; } // or if we need to enlarge data else if 
(data->digits_used == data->size) { data_record* temp = new data_record(data->digits_used+slack); short_mul(data,rhs,temp); delete data; data = temp; } else { short_mul(data,rhs,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator/ ( uint16 lhs, const bigint_kernel_2& rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record(rhs.slack); // if rhs might not be bigger than lhs if (rhs.data->digits_used == 1) { *(temp->number) = lhs/ *(rhs.data->number); } return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator/ ( const bigint_kernel_2& lhs, uint16 rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record (lhs.data->digits_used+lhs.slack); uint16 remainder; lhs.short_div(lhs.data,rhs,temp,remainder); return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator/= ( uint16 rhs ) { uint16 remainder; // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_div(data,rhs,temp,remainder); data = temp; } else { short_div(data,rhs,data,remainder); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator% ( uint16 lhs, const bigint_kernel_2& rhs ) { typedef bigint_kernel_2 bigint; // temp is zero by default bigint::data_record* temp = new bigint::data_record(rhs.slack); if (rhs.data->digits_used == 1) { // if rhs is just an uint16 inside then perform the modulus *(temp->number) = lhs % *(rhs.data->number); } else { // if rhs is bigger than lhs then the answer is lhs *(temp->number) = lhs; } return bigint_kernel_2(temp,0); 
} // ---------------------------------------------------------------------------------------- const bigint_kernel_2 operator% ( const bigint_kernel_2& lhs, uint16 rhs ) { typedef bigint_kernel_2 bigint; bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); uint16 remainder; lhs.short_div(lhs.data,rhs,temp,remainder); temp->digits_used = 1; *(temp->number) = remainder; return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator%= ( uint16 rhs ) { uint16 remainder; // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; short_div(data,rhs,temp,remainder); data = temp; } else { short_div(data,rhs,data,remainder); } data->digits_used = 1; *(data->number) = remainder; return *this; } // ---------------------------------------------------------------------------------------- bool operator < ( uint16 lhs, const bigint_kernel_2& rhs ) { return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); } // ---------------------------------------------------------------------------------------- bool operator < ( const bigint_kernel_2& lhs, uint16 rhs ) { return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); } // ---------------------------------------------------------------------------------------- bool operator == ( const bigint_kernel_2& lhs, uint16 rhs ) { return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); } // ---------------------------------------------------------------------------------------- bool operator == ( uint16 lhs, const bigint_kernel_2& rhs ) { return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator= ( uint16 rhs ) { // check if there are 
other references to our data if (data->references != 1) { data->references -= 1; try { data = new data_record(slack); } catch (...) { data->references += 1; throw; } } else { data->digits_used = 1; } *(data->number) = rhs; return *this; } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator++ ( ) { // if there are other references to this data then make a copy of it if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; increment(data,temp); data = temp; } // or if we need to enlarge data then do so else if (data->digits_used == data->size) { data_record* temp = new data_record(data->digits_used+slack); increment(data,temp); delete data; data = temp; } else { increment(data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator++ ( int ) { data_record* temp; // this is the copy of temp we will return in the end data_record* temp2 = new data_record(data->digits_used+slack); increment(data,temp2); temp = data; data = temp2; return bigint_kernel_2(temp,0); } // ---------------------------------------------------------------------------------------- bigint_kernel_2& bigint_kernel_2:: operator-- ( ) { // if there are other references to this data if (data->references != 1) { data_record* temp = new data_record(data->digits_used+slack); data->references -= 1; decrement(data,temp); data = temp; } else { decrement(data,data); } return *this; } // ---------------------------------------------------------------------------------------- const bigint_kernel_2 bigint_kernel_2:: operator-- ( int ) { data_record* temp; // this is the copy of temp we will return in the end data_record* temp2 = new data_record(data->digits_used+slack); decrement(data,temp2); temp = data; data = temp2; return bigint_kernel_2(temp,0); } // 
---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: short_add ( const data_record* data, uint16 value, data_record* result ) const { // put value into the carry part of temp uint32 temp = value; temp <<= 16; const uint16* number = data->number; const uint16* end = number + data->digits_used; // one past the end of number uint16* r = result->number; while (number != end) { // add *number and the current carry temp = *number + (temp>>16); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++number; ++r; } // if there is a final carry if ((temp>>16) != 0) { result->digits_used = data->digits_used + 1; // store the carry in the most significant digit of the result *r = static_cast(temp>>16); } else { result->digits_used = data->digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: short_sub ( const data_record* data, uint16 value, data_record* result ) const { const uint16* number = data->number; const uint16* end = number + data->digits_used - 1; uint16* r = result->number; uint32 temp = *number - value; // put the low word of temp into *data *r = static_cast(temp & 0xFFFF); while (number != end) { ++number; ++r; // subtract the carry from *number temp = *number - (temp>>31); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); } // if we lost a digit in the subtraction if (*r == 0) { if (data->digits_used == 1) result->digits_used = 1; else result->digits_used = data->digits_used - 1; } else { result->digits_used = data->digits_used; } } // 
---------------------------------------------------------------------------------------- void bigint_kernel_2:: short_mul ( const data_record* data, uint16 value, data_record* result ) const { uint32 temp = 0; const uint16* number = data->number; uint16* r = result->number; const uint16* end = r + data->digits_used; while ( r != end) { // multiply *data and value and add in the carry temp = *number*(uint32)value + (temp>>16); // put the low word of temp into *data *r = static_cast(temp & 0xFFFF); ++number; ++r; } // if there is a final carry if ((temp>>16) != 0) { result->digits_used = data->digits_used + 1; // put the final carry into the most significant digit of the result *r = static_cast(temp>>16); } else { result->digits_used = data->digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: short_div ( const data_record* data, uint16 value, data_record* result, uint16& rem ) const { uint16 remainder = 0; uint32 temp; const uint16* number = data->number + data->digits_used - 1; const uint16* end = number - data->digits_used; uint16* r = result->number + data->digits_used - 1; // if we are losing a digit in this division if (*number < value) { if (data->digits_used == 1) result->digits_used = 1; else result->digits_used = data->digits_used - 1; } else { result->digits_used = data->digits_used; } // perform the actual division while (number != end) { temp = *number + (((uint32)remainder)<<16); *r = static_cast(temp/value); remainder = static_cast(temp%value); --number; --r; } rem = remainder; } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: long_add ( const data_record* lhs, const data_record* rhs, data_record* result ) const { // put value into the carry part of temp uint32 temp=0; uint16* min_num; // the number with the least digits used uint16* max_num; // the number with the most digits used uint16* min_end; // one 
past the end of min_num uint16* max_end; // one past the end of max_num uint16* r = result->number; uint32 max_digits_used; if (lhs->digits_used < rhs->digits_used) { max_digits_used = rhs->digits_used; min_num = lhs->number; max_num = rhs->number; min_end = min_num + lhs->digits_used; max_end = max_num + rhs->digits_used; } else { max_digits_used = lhs->digits_used; min_num = rhs->number; max_num = lhs->number; min_end = min_num + rhs->digits_used; max_end = max_num + lhs->digits_used; } while (min_num != min_end) { // add *min_num, *max_num and the current carry temp = *min_num + *max_num + (temp>>16); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++min_num; ++max_num; ++r; } while (max_num != max_end) { // add *max_num and the current carry temp = *max_num + (temp>>16); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++max_num; ++r; } // check if there was a final carry if ((temp>>16) != 0) { result->digits_used = max_digits_used + 1; // put the carry into the most significant digit in the result *r = static_cast(temp>>16); } else { result->digits_used = max_digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: long_sub ( const data_record* lhs, const data_record* rhs, data_record* result ) const { const uint16* number1 = lhs->number; const uint16* number2 = rhs->number; const uint16* end = number2 + rhs->digits_used; uint16* r = result->number; uint32 temp =0; while (number2 != end) { // subtract *number2 from *number1 and then subtract any carry temp = *number1 - *number2 - (temp>>31); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++number1; ++number2; ++r; } end = lhs->number + lhs->digits_used; while (number1 != end) { // subtract the carry from *number1 temp = *number1 - (temp>>31); // put the low word of temp into *r *r = static_cast(temp & 0xFFFF); ++number1; ++r; } result->digits_used = lhs->digits_used; // 
adjust the number of digits used appropriately --r; while (*r == 0 && result->digits_used > 1) { --r; --result->digits_used; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: long_div ( const data_record* lhs, const data_record* rhs, data_record* result, data_record* remainder ) const { // zero result result->digits_used = 1; *(result->number) = 0; uint16* a; uint16* b; uint16* end; // copy lhs into remainder remainder->digits_used = lhs->digits_used; a = remainder->number; end = a + remainder->digits_used; b = lhs->number; while (a != end) { *a = *b; ++a; ++b; } // if rhs is bigger than lhs then result == 0 and remainder == lhs // so then we can quit right now if (is_less_than(lhs,rhs)) { return; } // make a temporary number data_record temp(lhs->digits_used + slack); // shift rhs left until it is one shift away from being larger than lhs and // put the number of left shifts necessary into shifts uint32 shifts; shifts = (lhs->digits_used - rhs->digits_used) * 16; shift_left(rhs,&temp,shifts); // while (lhs > temp) while (is_less_than(&temp,lhs)) { shift_left(&temp,&temp,1); ++shifts; } // make sure lhs isn't smaller than temp while (is_less_than(lhs,&temp)) { shift_right(&temp,&temp); --shifts; } // we want to execute the loop shifts +1 times ++shifts; while (shifts != 0) { shift_left(result,result,1); // if (temp <= remainder) if (!is_less_than(remainder,&temp)) { long_sub(remainder,&temp,remainder); // increment result uint16* r = result->number; uint16* end = r + result->digits_used; while (true) { ++(*r); // if there was no carry then we are done if (*r != 0) break; ++r; // if we hit the end of r and there is still a carry then // the next digit of r is 1 and there is one more digit used if (r == end) { *r = 1; ++(result->digits_used); break; } } } shift_right(&temp,&temp); --shifts; } } // ---------------------------------------------------------------------------------------- void 
bigint_kernel_2:: long_mul ( const data_record* lhs, const data_record* rhs, data_record* result ) const { // if one of the numbers is small then use this simple but O(n^2) algorithm if (std::min(lhs->digits_used, rhs->digits_used) < 10) { // make result be zero result->digits_used = 1; *(result->number) = 0; const data_record* aa; const data_record* bb; if (lhs->digits_used < rhs->digits_used) { // make copies of lhs and rhs and give them an appropriate amount of // extra memory so there won't be any overflows aa = lhs; bb = rhs; } else { // make copies of lhs and rhs and give them an appropriate amount of // extra memory so there won't be any overflows aa = rhs; bb = lhs; } // copy the larger(approximately) of lhs and rhs into b data_record b(*bb,aa->digits_used+slack); uint32 shift_value = 0; uint16* anum = aa->number; uint16* end = anum + aa->digits_used; while (anum != end ) { uint16 bit = 0x0001; for (int i = 0; i < 16; ++i) { // if the specified bit of a is 1 if ((*anum & bit) != 0) { shift_left(&b,&b,shift_value); shift_value = 0; long_add(&b,result,result); } ++shift_value; bit <<= 1; } ++anum; } } else // else if both lhs and rhs are large then use the more complex // O(n*logn) algorithm { uint32 size = 1; // make size a power of 2 while (size < (lhs->digits_used + rhs->digits_used)*2) { size *= 2; } // allocate some temporary space so we can do the FFT ct* a = new ct[size]; ct* b; try {b = new ct[size]; } catch (...) { delete [] a; throw; } // load lhs into the a array. We are breaking the input number into // 8bit chunks for the purpose of using this fft algorithm. The reason // for this is so that we have smaller numbers coming out of the final // ifft. This helps avoid overflow. 
for (uint32 i = 0; i < lhs->digits_used; ++i) { a[i*2] = ct((t)(lhs->number[i]&0xFF),0); a[i*2+1] = ct((t)(lhs->number[i]>>8),0); } for (uint32 i = lhs->digits_used*2; i < size; ++i) { a[i] = 0; } // load rhs into the b array for (uint32 i = 0; i < rhs->digits_used; ++i) { b[i*2] = ct((t)(rhs->number[i]&0xFF),0); b[i*2+1] = ct((t)(rhs->number[i]>>8),0); } for (uint32 i = rhs->digits_used*2; i < size; ++i) { b[i] = 0; } // perform the forward fft of a and b fft(a,size); fft(b,size); const double l = 1.0/size; // do the pointwise multiply of a and b and also apply the scale // factor in this loop too. for (unsigned long i = 0; i < size; ++i) { a[i] = l*a[i]*b[i]; } // Now compute the inverse fft of the pointwise multiplication of a and b. // This is basically the result. We just have to take care of any carries // that should happen. ifft(a,size); // loop over the result and propagate any carries that need to take place. // We will also be moving the resulting numbers into result->number at // the same time. uint64 carry = 0; result->digits_used = 0; int zeros = 0; const uint32 len = lhs->digits_used + rhs->digits_used; for (unsigned long i = 0; i < len; ++i) { uint64 num1 = static_cast(std::round(a[i*2].real())); num1 += carry; carry = 0; if (num1 > 255) { carry = num1 >> 8; num1 = (num1&0xFF); } uint64 num2 = static_cast(std::round(a[i*2+1].real())); num2 += carry; carry = 0; if (num2 > 255) { carry = num2 >> 8; num2 = (num2&0xFF); } // put the new number into its final place num1 = (num2<<8) | num1; result->number[i] = static_cast(num1); // keep track of the number of leading zeros if (num1 == 0) ++zeros; else zeros = 0; ++(result->digits_used); } // adjust digits_used so that it reflects the actual number // of non-zero digits in our representation. 
result->digits_used -= zeros; // if the result was zero then adjust the result accordingly if (result->digits_used == 0) { // make result be zero result->digits_used = 1; *(result->number) = 0; } // free all the temporary buffers delete [] a; delete [] b; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: shift_left ( const data_record* data, data_record* result, uint32 shift_amount ) const { uint32 offset = shift_amount/16; shift_amount &= 0xf; // same as shift_amount %= 16; uint16* r = result->number + data->digits_used + offset; // result uint16* end = data->number; uint16* s = end + data->digits_used; // source const uint32 temp = 16 - shift_amount; *r = (*(--s) >> temp); // set the number of digits used in the result // if the upper bits from *s were zero then don't count this first word if (*r == 0) { result->digits_used = data->digits_used + offset; } else { result->digits_used = data->digits_used + offset + 1; } --r; while (s != end) { *r = ((*s << shift_amount) | ( *(s-1) >> temp)); --r; --s; } *r = *s << shift_amount; // now zero the rest of the result end = result->number; while (r != end) *(--r) = 0; } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: shift_right ( const data_record* data, data_record* result ) const { uint16* r = result->number; // result uint16* s = data->number; // source uint16* end = s + data->digits_used - 1; while (s != end) { *r = (*s >> 1) | (*(s+1) << 15); ++r; ++s; } *r = *s >> 1; // calculate the new number for digits_used if (*r == 0) { if (data->digits_used != 1) result->digits_used = data->digits_used - 1; else result->digits_used = 1; } else { result->digits_used = data->digits_used; } } // ---------------------------------------------------------------------------------------- bool bigint_kernel_2:: is_less_than ( const data_record* lhs, const data_record* rhs ) const { uint32 
lhs_digits_used = lhs->digits_used; uint32 rhs_digits_used = rhs->digits_used; // if lhs is definitely less than rhs if (lhs_digits_used < rhs_digits_used ) return true; // if lhs is definitely greater than rhs else if (lhs_digits_used > rhs_digits_used) return false; else { uint16* end = lhs->number; uint16* l = end + lhs_digits_used; uint16* r = rhs->number + rhs_digits_used; while (l != end) { --l; --r; if (*l < *r) return true; else if (*l > *r) return false; } // at this point we know that they are equal return false; } } // ---------------------------------------------------------------------------------------- bool bigint_kernel_2:: is_equal_to ( const data_record* lhs, const data_record* rhs ) const { // if lhs and rhs are definitely not equal if (lhs->digits_used != rhs->digits_used ) { return false; } else { uint16* l = lhs->number; uint16* r = rhs->number; uint16* end = l + lhs->digits_used; while (l != end) { if (*l != *r) return false; ++l; ++r; } // at this point we know that they are equal return true; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: increment ( const data_record* source, data_record* dest ) const { uint16* s = source->number; uint16* d = dest->number; uint16* end = s + source->digits_used; while (true) { *d = *s + 1; // if there was no carry then break out of the loop if (*d != 0) { dest->digits_used = source->digits_used; // copy the rest of the digits over to d ++d; ++s; while (s != end) { *d = *s; ++d; ++s; } break; } ++s; // if we have hit the end of s and there was a carry up to this point // then just make the next digit 1 and add one to the digits used if (s == end) { ++d; dest->digits_used = source->digits_used + 1; *d = 1; break; } ++d; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: decrement ( const data_record* source, data_record* dest ) const { uint16* s = source->number; uint16* 
d = dest->number; uint16* end = s + source->digits_used; while (true) { *d = *s - 1; // if there was no carry then break out of the loop if (*d != 0xFFFF) { // if we lost a digit in the subtraction if (*d == 0 && s+1 == end) { if (source->digits_used == 1) dest->digits_used = 1; else dest->digits_used = source->digits_used - 1; } else { dest->digits_used = source->digits_used; } break; } else { ++d; ++s; } } // copy the rest of the digits over to d ++d; ++s; while (s != end) { *d = *s; ++d; ++s; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: fft ( ct* data, unsigned long len ) const { const t pi2 = -2.0*3.1415926535897932384626433832795028841971693993751; const unsigned long half = len/2; std::vector twiddle_factors; twiddle_factors.resize(half); // compute the complex root of unity w const t temp = pi2/len; ct w = ct(std::cos(temp),std::sin(temp)); ct w_pow = 1; // compute the twiddle factors for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) { twiddle_factors[j] = w_pow; w_pow *= w; } ct a, b; // now compute the decimation in frequency. 
This first // outer loop loops log2(len) number of times unsigned long skip = 1; for (unsigned long step = half; step != 0; step >>= 1) { // do blocks of butterflies in this loop for (unsigned long j = 0; j < len; j += step*2) { // do step butterflies for (unsigned long k = 0; k < step; ++k) { const unsigned long a_idx = j+k; const unsigned long b_idx = j+k+step; a = data[a_idx] + data[b_idx]; b = (data[a_idx] - data[b_idx])*twiddle_factors[k*skip]; data[a_idx] = a; data[b_idx] = b; } } skip *= 2; } } // ---------------------------------------------------------------------------------------- void bigint_kernel_2:: ifft( ct* data, unsigned long len ) const { const t pi2 = 2.0*3.1415926535897932384626433832795028841971693993751; const unsigned long half = len/2; std::vector twiddle_factors; twiddle_factors.resize(half); // compute the complex root of unity w const t temp = pi2/len; ct w = ct(std::cos(temp),std::sin(temp)); ct w_pow = 1; // compute the twiddle factors for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) { twiddle_factors[j] = w_pow; w_pow *= w; } ct a, b; // now compute the inverse decimation in frequency. This first // outer loop loops log2(len) number of times unsigned long skip = half; for (unsigned long step = 1; step <= half; step <<= 1) { // do blocks of butterflies in this loop for (unsigned long j = 0; j < len; j += step*2) { // do step butterflies for (unsigned long k = 0; k < step; ++k) { const unsigned long a_idx = j+k; const unsigned long b_idx = j+k+step; data[b_idx] *= twiddle_factors[k*skip]; a = data[a_idx] + data[b_idx]; b = data[a_idx] - data[b_idx]; data[a_idx] = a; data[b_idx] = b; } } skip /= 2; } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIGINT_KERNEL_2_CPp_ ================================================ FILE: dlib/bigint/bigint_kernel_2.h ================================================ // Copyright (C) 2003 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIGINT_KERNEl_2_ #define DLIB_BIGINT_KERNEl_2_ #include "bigint_kernel_abstract.h" #include "../algs.h" #include "../serialize.h" #include "../uintn.h" #include #include #include #include namespace dlib { class bigint_kernel_2 { /*! INITIAL VALUE slack == 25 data->number[0] == 0 data->size == slack data->references == 1 data->digits_used == 1 CONVENTION slack == the number of extra digits placed into the number when it is created. the slack value should never be less than 1 data->number == pointer to an array of data->size uint16s. data represents a string of base 65535 numbers with data[0] being the least significant bit and data[data->digits_used-1] being the most significant NOTE: In the comments I will consider a word to be a 16 bit value data->digits_used == the number of significant digits in the number. data->digits_used tells us the number of used elements in the data->number array so everything beyond data->number[data->digits_used-1] is undefined data->references == the number of bigint_kernel_2 objects which refer to this data_record !*/ struct data_record { explicit data_record( uint32 size_ ) : size(size_), number(new uint16[size_]), references(1), digits_used(1) {*number = 0;} /*! ensures - initializes *this to represent zero !*/ data_record( const data_record& item, uint32 additional_size ) : size(item.digits_used + additional_size), number(new uint16[size]), references(1), digits_used(item.digits_used) { uint16* source = item.number; uint16* dest = number; uint16* end = source + digits_used; while (source != end) { *dest = *source; ++dest; ++source; } } /*! 
ensures - *this is a copy of item except with size == item.digits_used + additional_size !*/ ~data_record( ) { delete [] number; } const uint32 size; uint16* number; uint32 references; uint32 digits_used; private: // no copy constructor data_record ( data_record&); }; // note that the second parameter is just there // to resolve the ambiguity between this constructor and // bigint_kernel_2(uint32) explicit bigint_kernel_2 ( data_record* data_, int ): slack(25),data(data_) {} /*! ensures - *this is initialized with data_ as its data member !*/ public: bigint_kernel_2 ( ); bigint_kernel_2 ( uint32 value ); bigint_kernel_2 ( const bigint_kernel_2& item ); virtual ~bigint_kernel_2 ( ); const bigint_kernel_2 operator+ ( const bigint_kernel_2& rhs ) const; bigint_kernel_2& operator+= ( const bigint_kernel_2& rhs ); const bigint_kernel_2 operator- ( const bigint_kernel_2& rhs ) const; bigint_kernel_2& operator-= ( const bigint_kernel_2& rhs ); const bigint_kernel_2 operator* ( const bigint_kernel_2& rhs ) const; bigint_kernel_2& operator*= ( const bigint_kernel_2& rhs ); const bigint_kernel_2 operator/ ( const bigint_kernel_2& rhs ) const; bigint_kernel_2& operator/= ( const bigint_kernel_2& rhs ); const bigint_kernel_2 operator% ( const bigint_kernel_2& rhs ) const; bigint_kernel_2& operator%= ( const bigint_kernel_2& rhs ); bool operator < ( const bigint_kernel_2& rhs ) const; bool operator == ( const bigint_kernel_2& rhs ) const; bigint_kernel_2& operator= ( const bigint_kernel_2& rhs ); friend std::ostream& operator<< ( std::ostream& out, const bigint_kernel_2& rhs ); friend std::istream& operator>> ( std::istream& in, bigint_kernel_2& rhs ); bigint_kernel_2& operator++ ( ); const bigint_kernel_2 operator++ ( int ); bigint_kernel_2& operator-- ( ); const bigint_kernel_2 operator-- ( int ); friend const bigint_kernel_2 operator+ ( uint16 lhs, const bigint_kernel_2& rhs ); friend const bigint_kernel_2 operator+ ( const bigint_kernel_2& lhs, uint16 rhs ); 
bigint_kernel_2& operator+= ( uint16 rhs ); friend const bigint_kernel_2 operator- ( uint16 lhs, const bigint_kernel_2& rhs ); friend const bigint_kernel_2 operator- ( const bigint_kernel_2& lhs, uint16 rhs ); bigint_kernel_2& operator-= ( uint16 rhs ); friend const bigint_kernel_2 operator* ( uint16 lhs, const bigint_kernel_2& rhs ); friend const bigint_kernel_2 operator* ( const bigint_kernel_2& lhs, uint16 rhs ); bigint_kernel_2& operator*= ( uint16 rhs ); friend const bigint_kernel_2 operator/ ( uint16 lhs, const bigint_kernel_2& rhs ); friend const bigint_kernel_2 operator/ ( const bigint_kernel_2& lhs, uint16 rhs ); bigint_kernel_2& operator/= ( uint16 rhs ); friend const bigint_kernel_2 operator% ( uint16 lhs, const bigint_kernel_2& rhs ); friend const bigint_kernel_2 operator% ( const bigint_kernel_2& lhs, uint16 rhs ); bigint_kernel_2& operator%= ( uint16 rhs ); friend bool operator < ( uint16 lhs, const bigint_kernel_2& rhs ); friend bool operator < ( const bigint_kernel_2& lhs, uint16 rhs ); friend bool operator == ( const bigint_kernel_2& lhs, uint16 rhs ); friend bool operator == ( uint16 lhs, const bigint_kernel_2& rhs ); bigint_kernel_2& operator= ( uint16 rhs ); void swap ( bigint_kernel_2& item ) { data_record* temp = data; data = item.data; item.data = temp; } private: typedef double t; typedef std::complex ct; void fft( ct* data, unsigned long len ) const; /*! requires - len == x^n for some integer n (i.e. len is a power of 2) - len > 0 ensures - #data == the FT decimation in frequency of data !*/ void ifft( ct* data, unsigned long len ) const; /*! requires - len == x^n for some integer n (i.e. len is a power of 2) - len > 0 ensures - #data == the inverse decimation in frequency of data. (i.e. the inverse of what fft(data,len,-1) does to data) !*/ void long_add ( const data_record* lhs, const data_record* rhs, data_record* result ) const; /*! 
requires - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 ensures - result == lhs + rhs !*/ void long_sub ( const data_record* lhs, const data_record* rhs, data_record* result ) const; /*! requires - lhs >= rhs - result->size >= lhs->digits_used ensures - result == lhs - rhs !*/ void long_div ( const data_record* lhs, const data_record* rhs, data_record* result, data_record* remainder ) const; /*! requires - rhs != 0 - result->size >= lhs->digits_used - remainder->size >= lhs->digits_used - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) ensures - result == lhs / rhs - remainder == lhs % rhs !*/ void long_mul ( const data_record* lhs, const data_record* rhs, data_record* result ) const; /*! requires - result->size >= lhs->digits_used + rhs->digits_used - result != lhs - result != rhs ensures - result == lhs * rhs !*/ void short_add ( const data_record* data, uint16 value, data_record* result ) const; /*! requires - result->size >= data->size + 1 ensures - result == data + value !*/ void short_sub ( const data_record* data, uint16 value, data_record* result ) const; /*! requires - data >= value - result->size >= data->digits_used ensures - result == data - value !*/ void short_mul ( const data_record* data, uint16 value, data_record* result ) const; /*! requires - result->size >= data->digits_used + 1 ensures - result == data * value !*/ void short_div ( const data_record* data, uint16 value, data_record* result, uint16& remainder ) const; /*! requires - value != 0 - result->size >= data->digits_used ensures - result = data*value - remainder = data%value !*/ void shift_left ( const data_record* data, data_record* result, uint32 shift_amount ) const; /*! requires - result->size >= data->digits_used + shift_amount/8 + 1 ensures - result == data << shift_amount !*/ void shift_right ( const data_record* data, data_record* result ) const; /*! 
requires - result->size >= data->digits_used ensures - result == data >> 1 !*/ bool is_less_than ( const data_record* lhs, const data_record* rhs ) const; /*! ensures - returns true if lhs < rhs - returns false otherwise !*/ bool is_equal_to ( const data_record* lhs, const data_record* rhs ) const; /*! ensures - returns true if lhs == rhs - returns false otherwise !*/ void increment ( const data_record* source, data_record* dest ) const; /*! requires - dest->size >= source->digits_used + 1 ensures - dest = source + 1 !*/ void decrement ( const data_record* source, data_record* dest ) const; /*! requires source != 0 ensuers dest = source - 1 !*/ // member data const uint32 slack; data_record* data; }; inline void swap ( bigint_kernel_2& a, bigint_kernel_2& b ) { a.swap(b); } inline void serialize ( const bigint_kernel_2& item, std::ostream& out ) { std::ios::fmtflags oldflags = out.flags(); out << item << ' '; out.flags(oldflags); if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); } inline void deserialize ( bigint_kernel_2& item, std::istream& in ) { std::ios::fmtflags oldflags = in.flags(); in >> item; in.flags(oldflags); if (in.get() != ' ') { item = 0; throw serialization_error("Error deserializing object of type bigint_kernel_c"); } } inline bool operator> (const bigint_kernel_2& a, const bigint_kernel_2& b) { return b < a; } inline bool operator!= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a == b); } inline bool operator<= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(b < a); } inline bool operator>= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a < b); } } #ifdef NO_MAKEFILE #include "bigint_kernel_2.cpp" #endif #endif // DLIB_BIGINT_KERNEl_2_ ================================================ FILE: dlib/bigint/bigint_kernel_abstract.h ================================================ // Copyright (C) 2003 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BIGINT_KERNEl_ABSTRACT_ #ifdef DLIB_BIGINT_KERNEl_ABSTRACT_ #include #include "../algs.h" #include "../serialize.h" #include "../uintn.h" namespace dlib { class bigint { /*! INITIAL VALUE *this == 0 WHAT THIS OBJECT REPRESENTS This object represents an arbitrary precision unsigned integer the following operators are supported: operator + operator += operator - operator -= operator * operator *= operator / operator /= operator % operator %= operator == operator < operator = operator << (for writing to ostreams) operator >> (for reading from istreams) operator++ // pre increment operator++(int) // post increment operator-- // pre decrement operator--(int) // post decrement the other comparison operators(>, !=, <=, and >=) are available and come from the templates in dlib::relational_operators THREAD SAFETY bigint may be reference counted so it is very unthread safe. use with care in a multithreaded program !*/ public: bigint ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc if this is thrown the bigint will be unusable but will not leak memory !*/ bigint ( uint32 value ); /*! requires - value <= (2^32)-1 ensures - #*this is properly initialized - #*this == value throws - std::bad_alloc if this is thrown the bigint will be unusable but will not leak memory !*/ bigint ( const bigint& item ); /*! ensures - #*this is properly initialized - #*this == value throws - std::bad_alloc if this is thrown the bigint will be unusable but will not leak memory !*/ virtual ~bigint ( ); /*! ensures - all resources associated with #*this have been released !*/ const bigint operator+ ( const bigint& rhs ) const; /*! ensures - returns the result of adding rhs to *this throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator+= ( const bigint& rhs ); /*! 
ensures - #*this == *this + rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ const bigint operator- ( const bigint& rhs ) const; /*! requires - *this >= rhs ensures - returns the result of subtracting rhs from *this throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator-= ( const bigint& rhs ); /*! requires - *this >= rhs ensures - #*this == *this - rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ const bigint operator* ( const bigint& rhs ) const; /*! ensures - returns the result of multiplying *this and rhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator*= ( const bigint& rhs ); /*! ensures - #*this == *this * rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ const bigint operator/ ( const bigint& rhs ) const; /*! requires - rhs != 0 ensures - returns the result of dividing *this by rhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator/= ( const bigint& rhs ); /*! requires - rhs != 0 ensures - #*this == *this / rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ const bigint operator% ( const bigint& rhs ) const; /*! requires - rhs != 0 ensures - returns the result of *this mod rhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator%= ( const bigint& rhs ); /*! requires - rhs != 0 ensures - #*this == *this % rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ bool operator < ( const bigint& rhs ) const; /*! ensures - returns true if *this is less than rhs - returns false otherwise !*/ bool operator == ( const bigint& rhs ) const; /*! ensures - returns true if *this and rhs represent the same number - returns false otherwise !*/ bigint& operator= ( const bigint& rhs ); /*! 
ensures - #*this == rhs throws - std::bad_alloc if this function throws then it has no effect !*/ friend std::ostream& operator<< ( std::ostream& out, const bigint& rhs ); /*! ensures - the number in *this has been written to #out as a base ten number throws - std::bad_alloc if this function throws then it has no effect (nothing is written to out) !*/ friend std::istream& operator>> ( std::istream& in, bigint& rhs ); /*! ensures - reads a number from in and puts it into #*this - if (there is no positive base ten number on the input stream ) then - #in.fail() == true throws - std::bad_alloc if this function throws the value in rhs is undefined and some characters may have been read from in. rhs is still usable though, its value is just unknown. !*/ bigint& operator++ ( ); /*! ensures - #*this == *this + 1 - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ const bigint operator++ ( int ); /*! ensures - #*this == *this + 1 - returns *this throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator-- ( ); /*! requires - *this != 0 ensures - #*this == *this - 1 - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ const bigint operator-- ( int ); /*! requires - *this != 0 ensures - #*this == *this - 1 - returns *this throws - std::bad_alloc if this function throws then it has no effect !*/ void swap ( bigint& item ); /*! ensures - swaps *this and item !*/ // ------------------------------------------------------------------ // ---- The following functions are identical to the above ----- // ---- but take uint16 as one of their arguments. They --- // ---- exist only to allow for a more efficient implementation --- // ------------------------------------------------------------------ friend const bigint operator+ ( uint16 lhs, const bigint& rhs ); /*! 
requires - lhs <= 65535 ensures - returns the result of adding rhs to lhs throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator+ ( const bigint& lhs, uint16 rhs ); /*! requires - rhs <= 65535 ensures - returns the result of adding rhs to lhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator+= ( uint16 rhs ); /*! requires - rhs <= 65535 ensures - #*this == *this + rhs - returns #this throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator- ( uint16 lhs, const bigint& rhs ); /*! requires - lhs >= rhs - lhs <= 65535 ensures - returns the result of subtracting rhs from lhs throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator- ( const bigint& lhs, uint16 rhs ); /*! requires - lhs >= rhs - rhs <= 65535 ensures - returns the result of subtracting rhs from lhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator-= ( uint16 rhs ); /*! requires - *this >= rhs - rhs <= 65535 ensures - #*this == *this - rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator* ( uint16 lhs, const bigint& rhs ); /*! requires - lhs <= 65535 ensures - returns the result of multiplying lhs and rhs throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator* ( const bigint& lhs, uint16 rhs ); /*! requires - rhs <= 65535 ensures - returns the result of multiplying lhs and rhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator*= ( uint16 rhs ); /*! requires - rhs <= 65535 ensures - #*this == *this * rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator/ ( uint16 lhs, const bigint& rhs ); /*! 
requires - rhs != 0 - lhs <= 65535 ensures - returns the result of dividing lhs by rhs throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator/ ( const bigint& lhs, uint16 rhs ); /*! requires - rhs != 0 - rhs <= 65535 ensures - returns the result of dividing lhs by rhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator/= ( uint16 rhs ); /*! requires - rhs != 0 - rhs <= 65535 ensures - #*this == *this / rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator% ( uint16 lhs, const bigint& rhs ); /*! requires - rhs != 0 - lhs <= 65535 ensures - returns the result of lhs mod rhs throws - std::bad_alloc if this function throws then it has no effect !*/ friend const bigint operator% ( const bigint& lhs, uint16 rhs ); /*! requires - rhs != 0 - rhs <= 65535 ensures - returns the result of lhs mod rhs throws - std::bad_alloc if this function throws then it has no effect !*/ bigint& operator%= ( uint16 rhs ); /*! requires - rhs != 0 - rhs <= 65535 ensures - #*this == *this % rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ friend bool operator < ( uint16 lhs, const bigint& rhs ); /*! requires - lhs <= 65535 ensures - returns true if lhs is less than rhs - returns false otherwise !*/ friend bool operator < ( const bigint& lhs, uint16 rhs ); /*! requires - rhs <= 65535 ensures - returns true if lhs is less than rhs - returns false otherwise !*/ friend bool operator == ( const bigint& lhs, uint16 rhs ); /*! requires - rhs <= 65535 ensures - returns true if lhs and rhs represent the same number - returns false otherwise !*/ friend bool operator == ( uint16 lhs, const bigint& rhs ); /*! requires - lhs <= 65535 ensures - returns true if lhs and rhs represent the same number - returns false otherwise !*/ bigint& operator= ( uint16 rhs ); /*! 
requires - rhs <= 65535 ensures - #*this == rhs - returns #*this throws - std::bad_alloc if this function throws then it has no effect !*/ }; inline void swap ( bigint& a, bigint& b ) { a.swap(b); } /*! provides a global swap function !*/ void serialize ( const bigint& item, std::istream& in ); /*! provides serialization support !*/ void deserialize ( bigint& item, std::istream& in ); /*! provides deserialization support !*/ inline bool operator> (const bigint& a, const bigint& b) { return b < a; } inline bool operator!= (const bigint& a, const bigint& b) { return !(a == b); } inline bool operator<= (const bigint& a, const bigint& b) { return !(b < a); } inline bool operator>= (const bigint& a, const bigint& b) { return !(a < b); } } #endif // DLIB_BIGINT_KERNEl_ABSTRACT_ ================================================ FILE: dlib/bigint/bigint_kernel_c.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BIGINT_KERNEl_C_ #define DLIB_BIGINT_KERNEl_C_ #include "bigint_kernel_abstract.h" #include "../algs.h" #include "../assert.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- template < typename bigint_base > class bigint_kernel_c { bigint_base data; explicit bigint_kernel_c ( const bigint_base& item ) : data(item) {} public: bigint_kernel_c ( ); bigint_kernel_c ( uint32 value ); bigint_kernel_c ( const bigint_kernel_c& item ); ~bigint_kernel_c ( ); const bigint_kernel_c operator+ ( const bigint_kernel_c& rhs ) const; bigint_kernel_c& operator+= ( const bigint_kernel_c& rhs ); const bigint_kernel_c operator- ( const bigint_kernel_c& rhs ) const; bigint_kernel_c& operator-= ( const bigint_kernel_c& rhs ); const bigint_kernel_c operator* ( const bigint_kernel_c& rhs ) const; bigint_kernel_c& operator*= ( const bigint_kernel_c& rhs ); const bigint_kernel_c operator/ ( const bigint_kernel_c& rhs ) const; bigint_kernel_c& operator/= ( const bigint_kernel_c& rhs ); const bigint_kernel_c operator% ( const bigint_kernel_c& rhs ) const; bigint_kernel_c& operator%= ( const bigint_kernel_c& rhs ); bool operator < ( const bigint_kernel_c& rhs ) const; bool operator == ( const bigint_kernel_c& rhs ) const; bigint_kernel_c& operator= ( const bigint_kernel_c& rhs ); template friend std::ostream& operator<< ( std::ostream& out, const bigint_kernel_c& rhs ); template friend std::istream& operator>> ( std::istream& in, bigint_kernel_c& rhs ); bigint_kernel_c& operator++ ( ); const bigint_kernel_c operator++ ( int ); bigint_kernel_c& operator-- ( ); const bigint_kernel_c operator-- ( int ); template friend const bigint_kernel_c operator+ ( uint16 lhs, const bigint_kernel_c& rhs ); template friend const bigint_kernel_c operator+ ( const bigint_kernel_c& lhs, uint16 rhs ); bigint_kernel_c& operator+= ( uint16 rhs ); template friend const bigint_kernel_c operator- ( uint16 lhs, const bigint_kernel_c& rhs 
); template friend const bigint_kernel_c operator- ( const bigint_kernel_c& lhs, uint16 rhs ); bigint_kernel_c& operator-= ( uint16 rhs ); template friend const bigint_kernel_c operator* ( uint16 lhs, const bigint_kernel_c& rhs ); template friend const bigint_kernel_c operator* ( const bigint_kernel_c& lhs, uint16 rhs ); bigint_kernel_c& operator*= ( uint16 rhs ); template friend const bigint_kernel_c operator/ ( uint16 lhs, const bigint_kernel_c& rhs ); template friend const bigint_kernel_c operator/ ( const bigint_kernel_c& lhs, uint16 rhs ); bigint_kernel_c& operator/= ( uint16 rhs ); template friend const bigint_kernel_c operator% ( uint16 lhs, const bigint_kernel_c& rhs ); template friend const bigint_kernel_c operator% ( const bigint_kernel_c& lhs, uint16 rhs ); bigint_kernel_c& operator%= ( uint16 rhs ); template friend bool operator < ( uint16 lhs, const bigint_kernel_c& rhs ); template friend bool operator < ( const bigint_kernel_c& lhs, uint16 rhs ); template friend bool operator == ( const bigint_kernel_c& lhs, uint16 rhs ); template friend bool operator == ( uint16 lhs, const bigint_kernel_c& rhs ); bigint_kernel_c& operator= ( uint16 rhs ); void swap ( bigint_kernel_c& item ) { data.swap(item.data); } }; template < typename bigint_base > void swap ( bigint_kernel_c& a, bigint_kernel_c& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > inline void serialize ( const bigint_kernel_c& item, std::ostream& out ) { std::ios::fmtflags oldflags = out.flags(); out << item << ' '; out.flags(oldflags); if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); } template < typename bigint_base > inline void deserialize ( bigint_kernel_c& item, std::istream& in ) { std::ios::fmtflags oldflags = in.flags(); in >> item; in.flags(oldflags); if (in.get() != ' ') { item = 0; throw serialization_error("Error deserializing object of type 
bigint_kernel_c"); } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c:: bigint_kernel_c ( ) {} // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c:: bigint_kernel_c ( uint32 value ) : data(value) { // make sure requires clause is not broken DLIB_CASSERT( value <= 0xFFFFFFFF , "\tbigint::bigint(uint16)" << "\n\t value must be <= (2^32)-1" << "\n\tthis: " << this << "\n\tvalue: " << value ); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c:: bigint_kernel_c ( const bigint_kernel_c& item ) : data(item.data) {} // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c:: ~bigint_kernel_c ( ) {} // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator+ ( const bigint_kernel_c& rhs ) const { return bigint_kernel_c(data + rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator+= ( const bigint_kernel_c& rhs ) { data += rhs.data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator- ( const bigint_kernel_c& rhs ) const { // make sure 
requires clause is not broken DLIB_CASSERT( !(*this < rhs), "\tconst bigint bigint::operator-(const bigint&)" << "\n\t *this should not be less than rhs" << "\n\tthis: " << this << "\n\t*this: " << *this << "\n\trhs: " << rhs ); // call the real function return bigint_kernel_c(data-rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator-= ( const bigint_kernel_c& rhs ) { // make sure requires clause is not broken DLIB_CASSERT( !(*this < rhs), "\tbigint& bigint::operator-=(const bigint&)" << "\n\t *this should not be less than rhs" << "\n\tthis: " << this << "\n\t*this: " << *this << "\n\trhs: " << rhs ); // call the real function data -= rhs.data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator* ( const bigint_kernel_c& rhs ) const { return bigint_kernel_c(data * rhs.data ); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator*= ( const bigint_kernel_c& rhs ) { data *= rhs.data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator/ ( const bigint_kernel_c& rhs ) const { //make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0), "\tconst bigint bigint::operator/(const bigint&)" << "\n\t can't divide by zero" << "\n\tthis: " << this ); // call the real function return bigint_kernel_c(data/rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator/= ( const bigint_kernel_c& rhs ) { // make sure requires clause is not 
broken DLIB_CASSERT( !(rhs == 0), "\tbigint& bigint::operator/=(const bigint&)" << "\n\t can't divide by zero" << "\n\tthis: " << this ); // call the real function data /= rhs.data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator% ( const bigint_kernel_c& rhs ) const { // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0), "\tconst bigint bigint::operator%(const bigint&)" << "\n\t can't divide by zero" << "\n\tthis: " << this ); // call the real function return bigint_kernel_c(data%rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator%= ( const bigint_kernel_c& rhs ) { // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0), "\tbigint& bigint::operator%=(const bigint&)" << "\n\t can't divide by zero" << "\n\tthis: " << this ); // call the real function data %= rhs.data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bool bigint_kernel_c:: operator < ( const bigint_kernel_c& rhs ) const { return data < rhs.data; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bool bigint_kernel_c:: operator == ( const bigint_kernel_c& rhs ) const { return data == rhs.data; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator= ( const bigint_kernel_c& rhs ) { data = rhs.data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > std::ostream& operator<< ( std::ostream& out, const bigint_kernel_c& rhs ) { out << 
rhs.data; return out; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > std::istream& operator>> ( std::istream& in, bigint_kernel_c& rhs ) { in >> rhs.data; return in; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator++ ( ) { ++data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator++ ( int ) { return bigint_kernel_c(data++); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator-- ( ) { // make sure requires clause is not broken DLIB_CASSERT( !(*this == 0), "\tbigint& bigint::operator--()" << "\n\t *this to subtract from *this it must not be zero to begin with" << "\n\tthis: " << this ); // call the real function --data; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c bigint_kernel_c:: operator-- ( int ) { // make sure requires clause is not broken DLIB_CASSERT( !(*this == 0), "\tconst bigint bigint::operator--(int)" << "\n\t *this to subtract from *this it must not be zero to begin with" << "\n\tthis: " << this ); // call the real function return bigint_kernel_c(data--); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator+ ( uint16 l, const bigint_kernel_c& rhs ) { uint32 lhs = l; // make sure requires clause is not broken DLIB_CASSERT( lhs <= 65535, "\tconst bigint operator+(uint16, const bigint&)" << "\n\t lhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return 
bigint_kernel_c(static_cast(lhs)+rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator+ ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 65535, "\tconst bigint operator+(const bigint&, uint16)" << "\n\t rhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return bigint_kernel_c(lhs.data+static_cast(rhs)); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator+= ( uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 65535, "\tbigint& bigint::operator+=(uint16)" << "\n\t rhs must be <= 65535" << "\n\tthis: " << this << "\n\t*this: " << *this << "\n\trhs: " << rhs ); data += rhs; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator- ( uint16 l, const bigint_kernel_c& rhs ) { uint32 lhs = l; // make sure requires clause is not broken DLIB_CASSERT( !(static_cast(lhs) < rhs) && lhs <= 65535, "\tconst bigint operator-(uint16,const bigint&)" << "\n\t lhs must be greater than or equal to rhs and lhs <= 65535" << "\n\tlhs: " << lhs << "\n\trhs: " << rhs << "\n\t&lhs: " << &lhs << "\n\t&rhs: " << &rhs ); // call the real function return bigint_kernel_c(static_cast(lhs)-rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator- ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( !(lhs < static_cast(rhs)) && rhs <= 65535, "\tconst bigint operator-(const bigint&,uint16)" << "\n\t lhs must be greater than or equal to rhs and rhs <= 
65535" << "\n\tlhs: " << lhs << "\n\trhs: " << rhs << "\n\t&lhs: " << &lhs << "\n\t&rhs: " << &rhs ); // call the real function return bigint_kernel_c(lhs.data-static_cast(rhs)); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator-= ( uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( !(*this < static_cast(rhs)) && rhs <= 65535, "\tbigint& bigint::operator-=(uint16)" << "\n\t *this must not be less than rhs and rhs <= 65535" << "\n\tthis: " << this << "\n\t*this: " << *this << "\n\trhs: " << rhs ); // call the real function data -= static_cast(rhs); return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator* ( uint16 l, const bigint_kernel_c& rhs ) { uint32 lhs = l; // make sure requires clause is not broken DLIB_CASSERT( lhs <= 65535, "\tconst bigint operator*(uint16, const bigint&)" << "\n\t lhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return bigint_kernel_c(lhs*rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator* ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 65535, "\tconst bigint operator*(const bigint&, uint16)" << "\n\t rhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return bigint_kernel_c(lhs.data*rhs); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator*= ( uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 65535, "\t bigint bigint::operator*=(uint16)" << "\n\t rhs must be <= 65535" << "\n\tthis: " 
<< this << "\n\t*this: " << *this << "\n\trhs: " << rhs ); data *= static_cast(rhs); return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator/ ( uint16 l, const bigint_kernel_c& rhs ) { uint32 lhs = l; // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0) && lhs <= 65535, "\tconst bigint operator/(uint16,const bigint&)" << "\n\t you can't divide by zero and lhs <= 65535" << "\n\t&lhs: " << &lhs << "\n\t&rhs: " << &rhs << "\n\tlhs: " << lhs ); // call the real function return bigint_kernel_c(lhs/rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator/ ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, "\tconst bigint operator/(const bigint&,uint16)" << "\n\t you can't divide by zero and rhs <= 65535" << "\n\t&lhs: " << &lhs << "\n\t&rhs: " << &rhs << "\n\trhs: " << rhs ); // call the real function return bigint_kernel_c(lhs.data/static_cast(rhs)); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator/= ( uint16 rhs ) { // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0) && static_cast(rhs) <= 65535, "\tbigint& bigint::operator/=(uint16)" << "\n\t you can't divide by zero and rhs must be <= 65535" << "\n\tthis: " << this << "\n\trhs: " << rhs ); // call the real function data /= rhs; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator% ( uint16 lhs, const bigint_kernel_c& rhs ) { // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0) && static_cast(lhs) <= 65535, 
"\tconst bigint operator%(uint16,const bigint&)" << "\n\t you can't divide by zero and lhs must be <= 65535" << "\n\t&lhs: " << &lhs << "\n\t&rhs: " << &rhs << "\n\tlhs: " << lhs ); // call the real function return bigint_kernel_c(lhs%rhs.data); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > const bigint_kernel_c operator% ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, "\tconst bigint operator%(const bigint&,uint16)" << "\n\t you can't divide by zero and rhs must be <= 65535" << "\n\t&lhs: " << &lhs << "\n\t&rhs: " << &rhs << "\n\trhs: " << rhs ); // call the real function return bigint_kernel_c(lhs.data%static_cast(rhs)); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator%= ( uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, "\tbigint& bigint::operator%=(uint16)" << "\n\t you can't divide by zero and rhs must be <= 65535" << "\n\tthis: " << this << "\n\trhs: " << rhs ); // call the real function data %= rhs; return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bool operator < ( uint16 l, const bigint_kernel_c& rhs ) { uint32 lhs = l; // make sure requires clause is not broken DLIB_CASSERT( lhs <= 65535, "\tbool operator<(uint16, const bigint&)" << "\n\t lhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return static_cast(lhs) < rhs.data; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bool operator < ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 
65535, "\tbool operator<(const bigint&, uint16)" << "\n\t rhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return lhs.data < static_cast(rhs); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bool operator == ( const bigint_kernel_c& lhs, uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 65535, "\tbool operator==(const bigint&, uint16)" << "\n\t rhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return lhs.data == static_cast(rhs); } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bool operator == ( uint16 l, const bigint_kernel_c& rhs ) { uint32 lhs = l; // make sure requires clause is not broken DLIB_CASSERT( lhs <= 65535, "\tbool operator==(uint16, const bigint&)" << "\n\t lhs must be <= 65535" << "\n\trhs: " << rhs << "\n\tlhs: " << lhs ); return static_cast(lhs) == rhs.data; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > bigint_kernel_c& bigint_kernel_c:: operator= ( uint16 r ) { uint32 rhs = r; // make sure requires clause is not broken DLIB_CASSERT( rhs <= 65535, "\tbigint bigint::operator=(uint16)" << "\n\t rhs must be <= 65535" << "\n\t*this: " << *this << "\n\tthis: " << this << "\n\tlhs: " << rhs ); data = static_cast(rhs); return *this; } // ---------------------------------------------------------------------------------------- template < typename bigint_base > inline bool operator> (const bigint_kernel_c& a, const bigint_kernel_c& b) { return b < a; } template < typename bigint_base > inline bool operator!= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a == b); } template < typename bigint_base > inline bool operator<= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(b < a); } template < typename 
bigint_base > inline bool operator>= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a < b); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIGINT_KERNEl_C_ ================================================ FILE: dlib/bigint.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIGINt_ #define DLIB_BIGINt_ #include "bigint/bigint_kernel_1.h" #include "bigint/bigint_kernel_2.h" #include "bigint/bigint_kernel_c.h" namespace dlib { class bigint { bigint() {} public: //----------- kernels --------------- // kernel_1a typedef bigint_kernel_1 kernel_1a; typedef bigint_kernel_c kernel_1a_c; // kernel_2a typedef bigint_kernel_2 kernel_2a; typedef bigint_kernel_c kernel_2a_c; }; } #endif // DLIB_BIGINt_ ================================================ FILE: dlib/binary_search_tree/binary_search_tree_kernel_1.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_1_ #define DLIB_BINARY_SEARCH_TREE_KERNEl_1_ #include "binary_search_tree_kernel_abstract.h" #include "../algs.h" #include "../interfaces/map_pair.h" #include "../interfaces/enumerable.h" #include "../interfaces/remover.h" #include "../serialize.h" #include #include namespace dlib { template < typename domain, typename range, typename mem_manager, typename compare = std::less > class binary_search_tree_kernel_1 : public enumerable >, public asc_pair_remover { /*! 
INITIAL VALUE tree_size == 0 tree_root == 0 tree_height == 0 at_start_ == true current_element == 0 stack == array of 50 node pointers stack_pos == 0 CONVENTION tree_size == size() tree_height == height() stack[stack_pos-1] == pop() current_element_valid() == (current_element != 0) if (current_element_valid()) then element() == current_element->d and current_element->r at_start_ == at_start() if (current_element != 0 && current_element != tree_root) then stack[stack_pos-1] == the parent of the node pointed to by current_element if (tree_size != 0) tree_root == pointer to the root node of the binary search tree else tree_root == 0 for all nodes: { left points to the left subtree or 0 if there is no left subtree and right points to the right subtree or 0 if there is no right subtree and all elements in a left subtree are <= the root and all elements in a right subtree are >= the root and d is the item in the domain of *this contained in the node r is the item in the range of *this contained in the node balance: balance == 0 if both subtrees have the same height balance == -1 if the left subtree has a height that is greater than the height of the right subtree by 1 balance == 1 if the right subtree has a height that is greater than the height of the left subtree by 1 for all trees: the height of the left and right subtrees differ by at most one } !*/ class node { public: node* left; node* right; domain d; range r; signed char balance; }; class mpair : public map_pair { public: const domain* d; range* r; const domain& key( ) const { return *d; } const range& value( ) const { return *r; } range& value( ) { return *r; } }; public: typedef domain domain_type; typedef range range_type; typedef compare compare_type; typedef mem_manager mem_manager_type; binary_search_tree_kernel_1( ) : tree_size(0), tree_root(0), current_element(0), tree_height(0), at_start_(true), stack_pos(0), stack(ppool.allocate_array(50)) { } virtual ~binary_search_tree_kernel_1( ); inline void clear( 
); inline short height ( ) const; inline unsigned long count ( const domain& item ) const; inline void add ( domain& d, range& r ); void remove ( const domain& d, domain& d_copy, range& r ); void destroy ( const domain& item ); inline const range* operator[] ( const domain& item ) const; inline range* operator[] ( const domain& item ); inline void swap ( binary_search_tree_kernel_1& item ); // function from the asc_pair_remover interface void remove_any ( domain& d, range& r ); // functions from the enumerable interface inline size_t size ( ) const; bool at_start ( ) const; inline void reset ( ) const; bool current_element_valid ( ) const; const map_pair& element ( ) const; map_pair& element ( ); bool move_next ( ) const; void remove_last_in_order ( domain& d, range& r ); void remove_current_element ( domain& d, range& r ); void position_enumerator ( const domain& d ) const; private: inline void rotate_left ( node*& t ); /*! requires - t->balance == 2 - t->right->balance == 0 or 1 - t == reference to the pointer in t's parent node that points to t ensures - #t is still a binary search tree - #t->balance is between 1 and -1 - #t now has a height smaller by 1 if #t->balance == 0 !*/ inline void rotate_right ( node*& t ); /*! requires - t->balance == -2 - t->left->balance == 0 or -1 - t == reference to the pointer in t's parent node that points to t ensures - #t is still a binary search tree - #t->balance is between 1 and -1 - #t now has a height smaller by 1 if #t->balance == 0 !*/ inline void double_rotate_right ( node*& t ); /*! requires - t->balance == -2 - t->left->balance == 1 - t == reference to the pointer in t's parent node that points to t ensures - #t is still a binary search tree - #t now has a balance of 0 - #t now has a height smaller by 1 !*/ inline void double_rotate_left ( node*& t ); /*! 
requires - t->balance == 2 - t->right->balance == -1 - t == reference to the pointer in t's parent node that points to t ensures - #t is still a binary search tree - #t now has a balance of 0 - #t now has a height smaller by 1 !*/ bool remove_biggest_element_in_tree ( node*& t, domain& d, range& r ); /*! requires - t != 0 (i.e. there must be something in the tree to remove) - t == reference to the pointer in t's parent node that points to t ensures - the biggest node in t has been removed - the biggest node domain element in t has been put into #d - the biggest node range element in t has been put into #r - #t is still a binary search tree - returns false if the height of the tree has not changed - returns true if the height of the tree has shrunk by one !*/ bool remove_least_element_in_tree ( node*& t, domain& d, range& r ); /*! requires - t != 0 (i.e. there must be something in the tree to remove) - t == reference to the pointer in t's parent node that points to t ensures - the least node in t has been removed - the least node domain element in t has been put into #d - the least node range element in t has been put into #r - #t is still a binary search tree - returns false if the height of the tree has not changed - returns true if the height of the tree has shrunk by one !*/ bool add_to_tree ( node*& t, domain& d, range& r ); /*! requires - t == reference to the pointer in t's parent node that points to t ensures - the mapping (d --> r) has been added to #t - #d and #r have initial values for their types - #t is still a binary search tree - returns false if the height of the tree has not changed - returns true if the height of the tree has grown by one !*/ bool remove_from_tree ( node*& t, const domain& d, domain& d_copy, range& r ); /*! 
requires - return_reference(t,d) != 0 - t == reference to the pointer in t's parent node that points to t ensures - #d_copy is equivalent to d - an element in t equivalent to d has been removed and swapped into #d_copy and its associated range object has been swapped into #r - #t is still a binary search tree - returns false if the height of the tree has not changed - returns true if the height of the tree has shrunk by one !*/ bool remove_from_tree ( node*& t, const domain& item ); /*! requires - return_reference(t,item) != 0 - t == reference to the pointer in t's parent node that points to t ensures - an element in t equivalent to item has been removed - #t is still a binary search tree - returns false if the height of the tree has not changed - returns true if the height of the tree has shrunk by one !*/ const range* return_reference ( const node* t, const domain& d ) const; /*! ensures - if (there is a domain element equivalent to d in t) then - returns a pointer to the element in the range equivalent to d - else - returns 0 !*/ range* return_reference ( node* t, const domain& d ); /*! ensures - if (there is a domain element equivalent to d in t) then - returns a pointer to the element in the range equivalent to d - else - returns 0 !*/ inline bool keep_node_balanced ( node*& t ); /*! requires - t != 0 - t == reference to the pointer in t's parent node that points to t ensures - if (t->balance is < 1 or > 1) then - keep_node_balanced() will ensure that #t->balance == 0, -1, or 1 - #t is still a binary search tree - returns true if it made the tree one height shorter - returns false if it didn't change the height !*/ unsigned long get_count ( const domain& item, node* tree_root ) const; /*! requires - tree_root == the root of a binary search tree or 0 ensures - if (tree_root == 0) then - returns 0 - else - returns the number of elements in tree_root that are equivalent to item !*/ void delete_tree ( node* t ); /*! 
requires - t != 0 ensures - deallocates the node pointed to by t and all of t's left and right children !*/ void push ( node* n ) const { stack[stack_pos] = n; ++stack_pos; } /*! ensures - pushes n onto the stack !*/ node* pop ( ) const { --stack_pos; return stack[stack_pos]; } /*! ensures - pops the top of the stack and returns it !*/ bool fix_stack ( node* t, unsigned char depth = 0 ); /*! requires - current_element != 0 - depth == 0 - t == tree_root ensures - makes the stack contain the correct set of parent pointers. also adjusts stack_pos so it is correct. - #t is still a binary search tree !*/ bool remove_current_element_from_tree ( node*& t, domain& d, range& r, unsigned long cur_stack_pos = 1 ); /*! requires - t == tree_root - cur_stack_pos == 1 - current_element != 0 ensures - removes the data in the node given by current_element and swaps it into #d and #r. - #t is still a binary search tree - the enumerator is advances on to the next element but its stack is potentially corrupted. so you must call fix_stack(tree_root) to fix it. 
- returns false if the height of the tree has not changed
                - returns true if the height of the tree has shrunk by one
        !*/


        // data members

        // Scratch pair handed out by element(); element() re-points p.d and p.r
        // at current_element's d and r on every call, which is why it is mutable.
        mutable mpair p;
        // Number of elements in the tree (returned by size()).
        unsigned long tree_size;
        // Root node of the tree; clear() resets it to 0 when the tree is emptied.
        node* tree_root;
        // Node the enumerator is on, or 0 when current_element_valid() is false.
        mutable node* current_element;
        // Allocator for tree nodes (pool.allocate() / pool.deallocate()).
        typename mem_manager::template rebind::other pool;
        // Allocator for the node* array held in stack (released with
        // deallocate_array() in the destructor).
        typename mem_manager::template rebind::other ppool;
        // Height reported by height(); adjusted by add()/remove()/etc. from the
        // height-change flags the private helpers return.
        short tree_height;
        // Value reported by at_start().
        mutable bool at_start_;
        // Number of entries currently on stack (see push() / pop()).
        mutable unsigned char stack_pos;
        // Ancestors of current_element on the path from tree_root; move_next()
        // pops these to climb back up the tree.
        mutable node** stack;
        // Ordering predicate used for every key comparison.
        compare comp;

        // restricted functions
        // NOTE(review): declared but not defined in this part of the file;
        // presumably left undefined so that copying a tree is rejected.
        binary_search_tree_kernel_1(binary_search_tree_kernel_1&);        // copy constructor
        binary_search_tree_kernel_1& operator=(binary_search_tree_kernel_1&);    // assignment operator

    };

    // Free-function swap: forwards to the member swap() so the two trees
    // exchange their entire state.
    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    inline void swap (
        binary_search_tree_kernel_1& a, 
        binary_search_tree_kernel_1& b 
    ) { a.swap(b); }   

    // Reads a tree written as an unsigned long element count followed by that
    // many (domain, range) pairs.  On a serialization_error the tree is left
    // empty and the error is rethrown with this type's name appended.
    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void deserialize (
        binary_search_tree_kernel_1& item, 
        std::istream& in
    )
    {
        try
        {
            item.clear();
            unsigned long size;
            deserialize(size,in);
            // d and r are reused on every iteration: add() swaps their contents
            // into the newly created node, and the next deserialize() calls
            // overwrite whatever was swapped back out.
            domain d;
            range r;
            for (unsigned long i = 0; i < size; ++i)
            {
                deserialize(d,in);
                deserialize(r,in);
                item.add(d,r);
            }
        }
        catch (serialization_error& e)
        { 
            item.clear();
            throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_1"); 
        }
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                        member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    binary_search_tree_kernel_1::
    ~binary_search_tree_kernel_1 (
    )
    {
        // Release the enumerator's ancestor stack, then every node.
        ppool.deallocate_array(stack);
        // delete_tree() requires a non-null node, hence the size check.
        if (tree_size != 0)
        {
            delete_tree(tree_root);
        }
    }

//
// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    clear (
    )
    {
        // Free every node and return to the empty state.  delete_tree()
        // requires a non-null node, so it is only called when there are
        // elements.
        if (tree_size > 0)
        {
            delete_tree(tree_root);
            tree_root = 0;
            tree_size = 0;
            tree_height = 0;
        }

        // reset the enumerator
        reset();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    size_t binary_search_tree_kernel_1::
    size (
    ) const
    {
        // number of elements currently stored
        return tree_size;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    short binary_search_tree_kernel_1::
    height (
    ) const
    {
        // cached height, maintained incrementally by the add/remove functions
        return tree_height;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    unsigned long binary_search_tree_kernel_1::
    count (
        const domain& item
    ) const
    {
        // number of stored elements equivalent to item
        return get_count(item,tree_root);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    add (
        domain& d,
        range& r
    )
    {
        // add_to_tree() swaps d and r into a new node and returns true when the
        // tree grew one level taller; that bool is added to the cached height.
        tree_height += add_to_tree(tree_root,d,r);
        ++tree_size;
        // reset the enumerator
        reset();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    remove (
        const domain& d,
        domain& d_copy,
        range& r
    )
    {
        // remove_from_tree() swaps an element equivalent to d into d_copy and r
        // and returns true when the tree became one level shorter.
        tree_height -= remove_from_tree(tree_root,d,d_copy,r);
        --tree_size;
        // reset the enumerator
        reset();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    destroy (
        const domain& item
    )
    {
        // Like remove(), but the removed element is freed along with its node
        // instead of being handed back to the caller.
        tree_height -= remove_from_tree(tree_root,item);
        --tree_size;
        // reset the enumerator
        reset();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    remove_any (
        domain& d,
        range& r
    )
    {
        // removes the least element (in comp order) and swaps it into d and r
        tree_height -= remove_least_element_in_tree(tree_root,d,r);
        --tree_size;
        // reset the enumerator
        reset();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    range* binary_search_tree_kernel_1::
    operator[] (
        const domain& item
    )
    {
        // pointer to the range value of an element equivalent to item, or 0 if
        // there is no such element
        return return_reference(tree_root,item);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    const range* binary_search_tree_kernel_1::
    operator[] (
        const domain& item
    ) const
    {
        // pointer to the range value of an element equivalent to item, or 0 if
        // there is no such element
        return return_reference(tree_root,item);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    swap (
        binary_search_tree_kernel_1& item
    )
    {
        // Exchange the complete state of the two trees: both allocators, the
        // scratch pair, the enumerator stack, the comparator, and the
        // tree/enumerator fields below.
        pool.swap(item.pool);
        ppool.swap(item.ppool);
        exchange(p,item.p);
        exchange(stack,item.stack);
        exchange(stack_pos,item.stack_pos);
        exchange(comp,item.comp);

        node* tree_root_temp            = item.tree_root;
        unsigned long tree_size_temp    = item.tree_size;
        short tree_height_temp          = item.tree_height;
        node* current_element_temp      = item.current_element;
        bool at_start_temp              = item.at_start_;

        item.tree_root = tree_root;
        item.tree_size = tree_size;
        item.tree_height = tree_height;
        item.current_element = current_element;
        item.at_start_ = at_start_;

        tree_root = tree_root_temp;
        tree_size = tree_size_temp;
        tree_height = tree_height_temp;
        current_element = current_element_temp;
        at_start_ = at_start_temp;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    remove_last_in_order (
        domain& d,
        range& r
    )
    {
        // removes the greatest element (in comp order) and swaps it into d and r
        tree_height -= remove_biggest_element_in_tree(tree_root,d,r);
        --tree_size;
        // reset the enumerator
        reset();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    remove_current_element (
        domain& d,
        range& r
    )
    {
        // Removes the element the enumerator is on, swapping it into d and r.
        // remove_current_element_from_tree() advances the enumerator but may
        // leave its ancestor stack corrupted, so the stack is rebuilt here.
        tree_height -= remove_current_element_from_tree(tree_root,d,r);
        --tree_size;

        // fix the enumerator stack if we need to (fix_stack() requires a
        // non-null current_element)
        if (current_element)
            fix_stack(tree_root);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    position_enumerator (
        const domain& d
    ) const
    {
        // Points the enumerator at an element equivalent to d if one exists;
        // otherwise at the next element after d in comp order, or leaves it
        // invalid (current_element == 0) if there is none.  Every node visited
        // on the way down is pushed so move_next() can climb back up later.

        // clear the enumerator state and make sure the stack is empty
        reset();
        at_start_ = false;

        node* t = tree_root;
        // direction of the last step taken before falling off the tree
        bool went_left = false;
        while (t != 0)
        {
            if ( comp(d , t->d) )
            {
                push(t);
                // if item is on the left then look in left
                t = t->left;
                went_left = true;
            }
            else if (comp(t->d , d))
            {
                push(t);
                // if item is on the right then look in right
                t = t->right;
                went_left = false;
            }
            else
            {
                current_element = t;
                return;
            }
        }

        // if we didn't find any matches but there might be something after the
        // d in this tree.
        if (stack_pos > 0)
        {
            current_element = pop();
            // if we went left from this node then this node is the next
            // biggest.
            if (went_left)
            {
                return;
            }
            else
            {
                // we went right, so this node precedes d; step to its successor
                move_next();
            }
        }
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                        enumerable function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    at_start (
    ) const
    {
        return at_start_;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    reset (
    ) const
    {
        // back to "before the first element": no current node, empty stack
        at_start_ = true;
        current_element = 0;
        stack_pos = 0;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    current_element_valid (
    ) const
    {
        return (current_element != 0);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    const map_pair& binary_search_tree_kernel_1::
    element (
    ) const
    {
        // Aim the scratch pair at current_element's key and value and return it.
        // current_element is dereferenced unchecked, so current_element_valid()
        // must be true.
        p.d = &(current_element->d);
        p.r = &(current_element->r);
        return p;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    map_pair& binary_search_tree_kernel_1::
    element (
    )
    {
        // Aim the scratch pair at current_element's key and value and return it.
        // current_element is dereferenced unchecked, so current_element_valid()
        // must be true.
        p.d = &(current_element->d);
        p.r = &(current_element->r);
        return p;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    move_next (
    ) const
    {
        // In-order successor step.  stack holds the ancestors of
        // current_element (pushed when descending, popped when climbing), so
        // nodes need no parent pointers.

        // if we haven't started iterating yet
        if (at_start_)
        {
            at_start_ = false;
            if (tree_size == 0)
            {
                return false;
            }
            else
            {
                // find the first element in the tree (the leftmost node)
                current_element = tree_root;
                node* temp = current_element->left;
                while (temp != 0)
                {
                    push(current_element);
                    current_element = temp;
                    temp = current_element->left;
                }
                return true;
            }
        }
        else
        {
            if (current_element == 0)
            {
                return false;
            }
            else
            {
                node* temp;
                bool went_up;  // true if we went up the tree from a child node to parent
                bool from_left = false; // true if we went up and were coming from a left child node
                // find the next element in the tree
                if (current_element->right != 0)
                {
                    // go right and down
                    temp = current_element;
                    push(current_element);
                    current_element = temp->right;
                    went_up = false;
                }
                else
                {
                    // go up to the parent if we can
                    if (current_element == tree_root)
                    {
                        // in this case we have iterated over all the elements of the tree
                        current_element = 0;
                        return false;
                    }
                    went_up = true;
                    node* parent = pop();

                    from_left = (parent->left == current_element);
                    // go up to parent
                    current_element = parent;
                }

                while (true)
                {
                    if (went_up)
                    {
                        if (from_left)
                        {
                            // in this case we have found the next node
                            break;
                        }
                        else
                        {
                            if (current_element == tree_root)
                            {
                                // in this case we have iterated over all the elements
                                // in the tree
                                current_element = 0;
                                return false;
                            }
                            // we should go up
                            node* parent = pop();
                            from_left = (parent->left == current_element);
                            current_element = parent;
                        }
                    }
                    else
                    {
                        // we just went down to a child node
                        if (current_element->left != 0)
                        {
                            // go left
                            went_up = false;
                            temp = current_element;
                            push(current_element);
                            current_element = temp->left;
                        }
                        else
                        {
                            // if there is no left child then we have found the next node
                            break;
                        }
                    }
                }

                return true;
            }
        }
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                        private member function definitions
//
---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_1:: delete_tree ( node* t ) { if (t->left != 0) delete_tree(t->left); if (t->right != 0) delete_tree(t->right); pool.deallocate(t); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_1:: rotate_left ( node*& t ) { // set the new balance numbers if (t->right->balance == 1) { t->balance = 0; t->right->balance = 0; } else { t->balance = 1; t->right->balance = -1; } // perform the rotation node* temp = t->right; t->right = temp->left; temp->left = t; t = temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_1:: rotate_right ( node*& t ) { // set the new balance numbers if (t->left->balance == -1) { t->balance = 0; t->left->balance = 0; } else { t->balance = -1; t->left->balance = 1; } // preform the rotation node* temp = t->left; t->left = temp->right; temp->right = t; t = temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_1:: double_rotate_right ( node*& t ) { node* temp = t; t = t->left->right; temp->left->right = t->left; t->left = temp->left; temp->left = t->right; t->right = temp; if (t->balance < 0) { t->left->balance = 0; t->right->balance = 1; } else if (t->balance > 0) { t->left->balance = -1; t->right->balance = 0; } else { t->left->balance = 0; t->right->balance = 0; } t->balance = 0; } // 
// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    void binary_search_tree_kernel_1::
    double_rotate_left (
        node*& t
    )
    {
        // Right-left case: the left child of t's right child becomes the new
        // subtree root, with the old root as its left child and the old right
        // child as its right child.
        node* temp = t;
        t = t->right->left;

        temp->right->left = t->right;
        t->right = temp->right;

        temp->right = t->left;
        t->left = temp;

        // new balance factors depend on which side of the new root was taller
        if (t->balance < 0)
        {
            t->left->balance = 0;
            t->right->balance = 1;
        }
        else if (t->balance > 0)
        {
            t->left->balance = -1;
            t->right->balance = 0;
        }
        else
        {
            t->left->balance = 0;
            t->right->balance = 0;
        }
        t->balance = 0;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    remove_biggest_element_in_tree (
        node*& t,
        domain& d,
        range& r
    )
    {
        // Removes the greatest element of the subtree rooted at t (its rightmost
        // node), swapping its contents into d and r.  Returns true if the
        // subtree became one level shorter.  t is dereferenced unchecked, so it
        // must not be null.

        // make a reference to the current node so we don't have to dereference a
        // pointer a bunch of times
        node& tree = *t;

        // if the right tree is an empty tree
        if ( tree.right == 0)
        {
            // swap node's domain and range elements into d and r
            exchange(d,tree.d);
            exchange(r,tree.r);

            // plug hole left by removing this node
            t = tree.left;

            // delete the node that was just removed
            pool.deallocate(&tree);

            // return that the height of this part of the tree has decreased
            return true;
        }
        else
        {
            // keep going right

            // if remove made the tree one height shorter
            if ( remove_biggest_element_in_tree(tree.right,d,r) )
            {
                // if this caused the current tree to shrink then report that
                if ( tree.balance == 1)
                {
                    --tree.balance;
                    return true;
                }
                else
                {
                    // the node now leans left (height unchanged) or has reached
                    // -2, which keep_node_balanced() repairs with a rotation
                    --tree.balance;
                    return keep_node_balanced(t);
                }
            }

            return false;
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    remove_least_element_in_tree (
        node*& t,
        domain& d,
        range& r
    )
    {
        // Removes the least element of the subtree rooted at t (its leftmost
        // node), swapping its contents into d and r.  Returns true if the
        // subtree became one level shorter.  t is dereferenced unchecked, so it
        // must not be null.

        // make a reference to the current node so we don't have to dereference a
        // pointer a bunch of times
        node& tree = *t;

        // if the left tree is an empty tree
        if ( tree.left == 0)
        {
            // swap node's domain and range elements into d and r
            exchange(d,tree.d);
            exchange(r,tree.r);

            // plug hole left by removing this node
            t = tree.right;

            // delete the node that was just removed
            pool.deallocate(&tree);

            // return that the height of this part of the tree has decreased
            return true;
        }
        else
        {
            // keep going left

            // if remove made the tree one height shorter
            if ( remove_least_element_in_tree(tree.left,d,r) )
            {
                // if this caused the current tree to shrink then report that
                if ( tree.balance == -1)
                {
                    ++tree.balance;
                    return true;
                }
                else
                {
                    // the node now leans right (height unchanged) or has reached
                    // +2, which keep_node_balanced() repairs with a rotation
                    ++tree.balance;
                    return keep_node_balanced(t);
                }
            }

            return false;
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    add_to_tree (
        node*& t,
        domain& d,
        range& r
    )
    {
        // Inserts (d, r) into the subtree rooted at t by swapping them into a
        // newly allocated node.  Elements that are not less than a node go to
        // its right, so equivalent keys are all kept.  Returns true if the
        // subtree grew one level taller.

        // if found place to add
        if (t == 0)
        {
            // create a node to add new item into
            t = pool.allocate();

            // make a reference to the current node so we don't have to dereference a
            // pointer a bunch of times
            node& tree = *t;


            // set left and right pointers to NULL to indicate that there are no
            // left or right subtrees
            tree.left = 0;
            tree.right = 0;
            tree.balance = 0;

            // put d and r into t
            exchange(tree.d,d);
            exchange(tree.r,r);

            // indicate that the height of this tree has increased
            return true;
        }
        else  // keep looking for a place to add the new item
        {
            // make a reference to the current node so we don't have to dereference
            // a pointer a bunch of times
            node& tree = *t;
            signed char old_balance = tree.balance;

            // add the new item to whatever subtree it should go into
            // (the bool returned by the recursive call counts as 0 or 1)
            if (comp( d , tree.d) )
                tree.balance -= add_to_tree(tree.left,d,r);
            else
                tree.balance += add_to_tree(tree.right,d,r);


            // if the tree was balanced to start with
            if (old_balance == 0)
            {
                // if it's not balanced anymore then it grew in height
                if (tree.balance != 0)
                    return true;
                else
                    return false;
            }
            else
            {
                // if the tree is now balanced then it didn't grow
                if (tree.balance == 0)
                {
                    return false;
                }
                else
                {
                    // if the tree needs to be balanced (balance is now +2 or -2)
                    if (tree.balance != old_balance)
                    {
                        // keep_node_balanced() returns true when the rotation
                        // made the subtree one level shorter, which cancels the
                        // growth from this insertion; hence the negation
                        return !keep_node_balanced(t);
                    }
                    // if there has been no change in the heights
                    else
                    {
                        return false;
                    }
                }
            }
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    fix_stack (
        node* t,
        unsigned char depth
    )
    {
        // Rebuilds stack[0 .. depth-1] with the ancestors of current_element by
        // searching for that node from t, and sets stack_pos to its depth.
        // Both subtrees are tried whenever the key comparison does not rule one
        // out, since elements equivalent to current_element may lie on either
        // side of a node.  Returns true once current_element has been found.

        // if we found the node we were looking for
        if (t == current_element)
        {
            stack_pos = depth;
            return true;
        }
        else if (t == 0)
        {
            return false;
        }

        if (!( comp(t->d , current_element->d)))
        {
            // go left
            if (fix_stack(t->left,depth+1))
            {
                stack[depth] = t;
                return true;
            }
        }

        if (!(comp(current_element->d , t->d)))
        {
            // go right
            if (fix_stack(t->right,depth+1))
            {
                stack[depth] = t;
                return true;
            }
        }
        return false;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    remove_current_element_from_tree (
        node*& t,
        domain& d,
        range& r,
        unsigned long cur_stack_pos 
    )
    {
        // Removes the node current_element from the subtree rooted at t,
        // swapping its contents into d and r and advancing the enumerator.  The
        // path down is taken from the enumerator's ancestor stack rather than
        // from key comparisons.  Returns true if the subtree became one level
        // shorter.

        // make a reference to the current node so we don't have to dereference
        // a pointer a bunch of times
        node& tree = *t;

        // if we found the node we were looking for
        if (t == current_element)
        {
            // swap node's domain and range elements into d and r
            exchange(d,tree.d);
            exchange(r,tree.r);

            // if there is no left node
            if (tree.left == 0)
            {
                // move the enumerator on to the next element before we mess with the
                // tree
                move_next();

                // plug hole left by removing this node and free memory
                t = tree.right;  // plug hole with right subtree

                // delete old node
                pool.deallocate(&tree);

                // indicate that the height has changed
                return true;
            }
            // if there is no right node
            else if (tree.right == 0)
            {
                // move the enumerator on to the next element before we mess with the
                // tree
                move_next();

                // plug hole left by removing this node and free memory
                t = tree.left;  // plug hole with left subtree

                // delete old node
                pool.deallocate(&tree);

                // indicate that the height of this tree has changed
                return true;
            }
            // if there are both a left and right sub node
            else
            {
                // in this case the next current element is going to get swapped back
                // into this t node.
                current_element = t;

                // get an element that can replace the one being removed and do this
                // if it made the right subtree shrink by one
                if (remove_least_element_in_tree(tree.right,tree.d,tree.r))
                {
                    // adjust the tree height
                    --tree.balance;

                    // if the height of the current tree has dropped by one
                    if (tree.balance == 0)
                    {
                        return true;
                    }
                    else
                    {
                        return keep_node_balanced(t);
                    }
                }
                // else this remove did not affect the height of this tree
                else
                {
                    return false;
                }

            }

        }
        // stack[cur_stack_pos] is the ancestor one level below t on the path to
        // current_element; when current_element is t's direct child there is no
        // such stack entry, hence the second test.
        else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.left) ||
                  tree.left == current_element )
        {
            // go left
            if (tree.balance == -1)
            {
                int balance = tree.balance;
                balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1);
                tree.balance = balance;
                return !tree.balance;
            }
            else
            {
                int balance = tree.balance;
                balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1);
                tree.balance = balance;
                return keep_node_balanced(t);
            }
        }
        else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.right) ||
                  tree.right == current_element )
        {
            // go right
            if (tree.balance == 1)
            {
                int balance = tree.balance;
                balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1);
                tree.balance = balance;
                return !tree.balance;
            }
            else
            {
                int balance = tree.balance;
                balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1);
                tree.balance = balance;
                return keep_node_balanced(t);
            }
        }

        // this return should never happen but do it anyway to suppress compiler warnings
        return false;
    }

//
// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    remove_from_tree (
        node*& t,
        const domain& d,
        domain& d_copy,
        range& r
    )
    {
        // Removes one element equivalent to d from the subtree rooted at t,
        // swapping its key into d_copy and its value into r.  Returns true if
        // the subtree became one level shorter.  t is dereferenced unchecked,
        // so an equivalent element must exist.

        // make a reference to the current node so we don't have to dereference
        // a pointer a bunch of times
        node& tree = *t;

        // if item is on the left
        if (comp(d , tree.d))
        {

            // if the left side of the tree has the greatest height
            if (tree.balance == -1)
            {
                // shrinking the taller side can only level this node, so its
                // height dropped exactly when the balance is now 0
                int balance = tree.balance;
                balance += remove_from_tree(tree.left,d,d_copy,r);
                tree.balance = balance;
                return !tree.balance;
            }
            else
            {
                int balance = tree.balance;
                balance += remove_from_tree(tree.left,d,d_copy,r);
                tree.balance = balance;
                return keep_node_balanced(t);
            }

        }
        // if item is on the right
        else if (comp(tree.d , d))
        {

            // if the right side of the tree has the greatest height
            if (tree.balance == 1)
            {
                int balance = tree.balance;
                balance -= remove_from_tree(tree.right,d,d_copy,r);
                tree.balance = balance;
                return !tree.balance;
            }
            else
            {
                int balance = tree.balance;
                balance -= remove_from_tree(tree.right,d,d_copy,r);
                tree.balance = balance;
                return keep_node_balanced(t);
            }
        }
        // if item is found
        else
        {

            // swap node's domain and range elements into d_copy and r
            exchange(d_copy,tree.d);
            exchange(r,tree.r);

            // if there is no left node
            if (tree.left == 0)
            {

                // plug hole left by removing this node and free memory
                t = tree.right;  // plug hole with right subtree

                // delete old node
                pool.deallocate(&tree);

                // indicate that the height has changed
                return true;
            }
            // if there is no right node
            else if (tree.right == 0)
            {

                // plug hole left by removing this node and free memory
                t = tree.left;  // plug hole with left subtree

                // delete old node
                pool.deallocate(&tree);

                // indicate that the height of this tree has changed
                return true;
            }
            // if there are both a left and right sub node
            else
            {
                // Pull the in-order successor (least element of the right
                // subtree) into this node; whatever tree.d/tree.r hold now is
                // swapped into the successor's node and freed with it.
                // get an element that can replace the one being removed and do this
                // if it made the right subtree shrink by one
                if (remove_least_element_in_tree(tree.right,tree.d,tree.r))
                {
                    // adjust the tree height
                    --tree.balance;

                    // if the height of the current tree has dropped by one
                    if (tree.balance == 0)
                    {
                        return true;
                    }
                    else
                    {
                        return keep_node_balanced(t);
                    }
                }
                // else this remove did not affect the height of this tree
                else
                {
                    return false;
                }

            }

        }

    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    remove_from_tree (
        node*& t,
        const domain& d
    )
    {
        // Same as the four-argument overload, except the removed element is
        // freed together with its node instead of being handed back.  Returns
        // true if the subtree became one level shorter.  t is dereferenced
        // unchecked, so an equivalent element must exist.

        // make a reference to the current node so we don't have to dereference
        // a pointer a bunch of times
        node& tree = *t;

        // if item is on the left
        if (comp(d , tree.d))
        {

            // if the left side of the tree has the greatest height
            if (tree.balance == -1)
            {
                int balance = tree.balance;
                balance += remove_from_tree(tree.left,d);
                tree.balance = balance;
                return !tree.balance;
            }
            else
            {
                int balance = tree.balance;
                balance += remove_from_tree(tree.left,d);
                tree.balance = balance;
                return keep_node_balanced(t);
            }

        }
        // if item is on the right
        else if (comp(tree.d , d))
        {

            // if the right side of the tree has the greatest height
            if (tree.balance == 1)
            {
                int balance = tree.balance;
                balance -= remove_from_tree(tree.right,d);
                tree.balance = balance;
                return !tree.balance;
            }
            else
            {
                int balance = tree.balance;
                balance -= remove_from_tree(tree.right,d);
                tree.balance = balance;
                return keep_node_balanced(t);
            }
        }
        // if item is found
        else
        {

            // if there is no left node
            if (tree.left == 0)
            {

                // plug hole left by removing this node and free memory
                t = tree.right;  // plug hole with right subtree

                // delete old node
                pool.deallocate(&tree);

                // indicate that the height has changed
                return true;
            }
            // if there is no right node
            else if (tree.right == 0)
            {

                // plug hole left by removing this node and free memory
                t = tree.left;  // plug hole with left subtree

                // delete old node
                pool.deallocate(&tree);

                // indicate that the height of this tree has changed
                return true;
            }
            // if there are both a left and right sub node
            else
            {
                // Pull the in-order successor into this node; the element being
                // destroyed is swapped into the successor's node and freed with it.
                // get an element that can replace the one being removed and do this
                // if it made the right subtree shrink by one
                if (remove_least_element_in_tree(tree.right,tree.d,tree.r))
                {
                    // adjust the tree height
                    --tree.balance;

                    // if the height of the current tree has dropped by one
                    if (tree.balance == 0)
                    {
                        return true;
                    }
                    else
                    {
                        return keep_node_balanced(t);
                    }
                }
                // else this remove did not affect the height of this tree
                else
                {
                    return false;
                }

            }

        }

    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    range* binary_search_tree_kernel_1::
    return_reference (
        node* t,
        const domain& d
    )
    {
        // Iterative lookup: pointer to the range of some element equivalent to
        // d in the subtree rooted at t, or 0 if there is none.
        while (t != 0)
        {

            if ( comp(d , t->d ))
            {
                // if item is on the left then look in left
                t = t->left;
            }
            else if (comp(t->d , d))
            {
                // if item is on the right then look in right
                t = t->right;
            }
            else
            {
                // if it's found then return a reference to it
                return &(t->r);
            }
        }
        return 0;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    const range* binary_search_tree_kernel_1::
    return_reference (
        const node* t,
        const domain& d
    ) const
    {
        // const counterpart of the lookup above
        while (t != 0)
        {

            if ( comp(d , t->d) )
            {
                // if item is on the left then look in left
                t = t->left;
            }
            else if (comp(t->d , d))
            {
                // if item is on the right then look in right
                t = t->right;
            }
            else
            {
                // if it's found then return a reference to it
                return &(t->r);
            }
        }
        return 0;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    bool binary_search_tree_kernel_1::
    keep_node_balanced (
        node*& t
    )
    {
        // Rotates t when its balance factor has reached +2 or -2 and reports
        // whether the subtree ended up one level shorter (true exactly when the
        // resulting root is level).  Balance factors of -1, 0 or +1 are left
        // untouched and yield false.

        // make a reference to the current node so we don't have to dereference
        // a pointer a bunch of times
        node& tree = *t;

        // if tree does not need to be balanced then return false
        if (tree.balance == 0)
            return false;


        // if tree needs to be rotated left
        if (tree.balance == 2)
        {
            // a single rotation suffices unless the right child leans left
            if (tree.right->balance >= 0)
                rotate_left(t);
            else
                double_rotate_left(t);
        }
        // else if the tree needs to be rotated right
        else if (tree.balance == -2)
        {
            // a single rotation suffices unless the left child leans right
            if (tree.left->balance <= 0)
                rotate_right(t);
            else
                double_rotate_right(t);
        }

        // t now points at the subtree's (possibly new) root
        if (t->balance == 0)
            return true;
        else
            return false; 
    }

// ----------------------------------------------------------------------------------------

    template <
        typename domain,
        typename range,
        typename mem_manager,
        typename compare
        >
    unsigned long binary_search_tree_kernel_1::
    get_count (
        const domain& d,
        node* tree_root
    ) const
    {
        // Counts elements equivalent to d in the subtree rooted at the
        // parameter tree_root (which shadows the member of the same name).
        if (tree_root != 0)
        {
            if (comp(d , tree_root->d))
            {
                // go left
                return get_count(d,tree_root->left);
            }
            else if (comp(tree_root->d , d))
            {
                // go right
                return get_count(d,tree_root->right);
            }
            else
            {
                // go left and right to look for more matches
                return   get_count(d,tree_root->left)
                       + get_count(d,tree_root->right)
                       + 1;
            }
        }
        return 0;
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_1_


================================================
FILE: dlib/binary_search_tree/binary_search_tree_kernel_2.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_2_ #define DLIB_BINARY_SEARCH_TREE_KERNEl_2_ #include "binary_search_tree_kernel_abstract.h" #include "../algs.h" #include "../interfaces/map_pair.h" #include "../interfaces/enumerable.h" #include "../interfaces/remover.h" #include "../serialize.h" #include namespace dlib { template < typename domain, typename range, typename mem_manager, typename compare = std::less > class binary_search_tree_kernel_2 : public enumerable >, public asc_pair_remover { /*! INITIAL VALUE NIL == pointer to a node that represents a leaf tree_size == 0 tree_root == NIL at_start == true current_element == 0 CONVENTION current_element_valid() == (current_element != 0) if (current_element_valid()) then element() == current_element->d and current_element->r at_start_ == at_start() tree_size == size() NIL == pointer to a node that represents a leaf if (tree_size != 0) tree_root == pointer to the root node of the binary search tree else tree_root == NIL tree_root->color == black Every leaf is black and all leafs are the NIL node. The number of black nodes in any path from the root to a leaf is the same. 
for all nodes: { - left points to the left subtree or NIL if there is no left subtree - right points to the right subtree or NIL if there is no right subtree - parent points to the parent node or NIL if the node is the root - ordering of nodes is determined by comparing each node's d member - all elements in a left subtree are <= the node - all elements in a right subtree are >= the node - color == red or black - if (color == red) - the node's children are black } !*/ class node { public: node* left; node* right; node* parent; domain d; range r; char color; }; class mpair : public map_pair { public: const domain* d; range* r; const domain& key( ) const { return *d; } const range& value( ) const { return *r; } range& value( ) { return *r; } }; const static char red = 0; const static char black = 1; public: typedef domain domain_type; typedef range range_type; typedef compare compare_type; typedef mem_manager mem_manager_type; binary_search_tree_kernel_2( ) : NIL(pool.allocate()), tree_size(0), tree_root(NIL), current_element(0), at_start_(true) { NIL->color = black; NIL->left = 0; NIL->right = 0; NIL->parent = 0; } virtual ~binary_search_tree_kernel_2( ); inline void clear( ); inline short height ( ) const; inline unsigned long count ( const domain& d ) const; inline void add ( domain& d, range& r ); void remove ( const domain& d, domain& d_copy, range& r ); void destroy ( const domain& d ); void remove_any ( domain& d, range& r ); inline const range* operator[] ( const domain& item ) const; inline range* operator[] ( const domain& item ); inline void swap ( binary_search_tree_kernel_2& item ); // functions from the enumerable interface inline size_t size ( ) const; bool at_start ( ) const; inline void reset ( ) const; bool current_element_valid ( ) const; const map_pair& element ( ) const; map_pair& element ( ); bool move_next ( ) const; void remove_last_in_order ( domain& d, range& r ); void remove_current_element ( domain& d, range& r ); void position_enumerator 
( const domain& d ) const; private: inline void rotate_left ( node* t ); /*! requires - t != NIL - t->right != NIL ensures - performs a left rotation around t and its right child !*/ inline void rotate_right ( node* t ); /*! requires - t != NIL - t->left != NIL ensures - performs a right rotation around t and its left child !*/ inline void double_rotate_right ( node* t ); /*! requires - t != NIL - t->left != NIL - t->left->right != NIL - double_rotate_right() is only called in fix_after_add() ensures - performs a left rotation around t->left - then performs a right rotation around t !*/ inline void double_rotate_left ( node* t ); /*! requires - t != NIL - t->right != NIL - t->right->left != NIL - double_rotate_left() is only called in fix_after_add() ensures - performs a right rotation around t->right - then performs a left rotation around t !*/ void remove_biggest_element_in_tree ( node* t, domain& d, range& r ); /*! requires - t != NIL (i.e. there must be something in the tree to remove) ensures - the biggest node in t has been removed - the biggest node element in t has been put into #d and #r - #t is still a binary search tree !*/ bool remove_least_element_in_tree ( node* t, domain& d, range& r ); /*! requires - t != NIL (i.e. there must be something in the tree to remove) ensures - the least node in t has been removed - the least node element in t has been put into #d and #r - #t is still a binary search tree - if (the node that was removed was the one pointed to by current_element) then - returns true - else - returns false !*/ void add_to_tree ( node* t, domain& d, range& r ); /*! requires - t != NIL ensures - d and r are now in #t - there is a mapping from d to r in #t - #d and #r have initial values for their types - #t is still a binary search tree !*/ void remove_from_tree ( node* t, const domain& d, domain& d_copy, range& r ); /*! 
requires - return_reference(t,d) != 0 ensures - #d_copy is equivalent to d - the first element in t equivalent to d that is encountered when searching down the tree from t has been removed and swapped into #d_copy. Also, the associated range element has been removed and swapped into #r. - if (the node that got removed wasn't current_element) then - adjusts the current_element pointer if the data in the node that it points to gets moved. - else - the value of current_element is now invalid - #t is still a binary search tree !*/ void remove_from_tree ( node* t, const domain& d ); /*! requires - return_reference(t,d) != 0 ensures - an element in t equivalent to d has been removed - #t is still a binary search tree !*/ const range* return_reference ( const node* t, const domain& d ) const; /*! ensures - if (there is a domain element equivalent to d in t) then - returns a pointer to the element in the range equivalent to d - else - returns 0 !*/ range* return_reference ( node* t, const domain& d ); /*! ensures - if (there is a domain element equivalent to d in t) then - returns a pointer to the element in the range equivalent to d - else - returns 0 !*/ void fix_after_add ( node* t ); /*! requires - t == pointer to the node just added - t->color == red - t->parent != NIL (t must not be the root) - fix_after_add() is only called after a new node has been added to t ensures - fixes any deviations from the CONVENTION caused by adding a node !*/ void fix_after_remove ( node* t ); /*! requires - t == pointer to the only child of the node that was spliced out - fix_after_remove() is only called after a node has been removed from t - the color of the spliced out node was black ensures - fixes any deviations from the CONVENTION causes by removing a node !*/ short tree_height ( node* t ) const; /*! ensures - returns the number of nodes in the longest path from the root of the tree to a leaf !*/ void delete_tree ( node* t ); /*! 
requires - t == root of binary search tree - t != NIL ensures - deletes all nodes in t except for NIL !*/ unsigned long get_count ( const domain& item, node* tree_root ) const; /*! requires - tree_root == the root of a binary search tree or NIL ensures - if (tree_root == NIL) then - returns 0 - else - returns the number of elements in tree_root that are equivalent to item !*/ // data members typename mem_manager::template rebind::other pool; node* NIL; unsigned long tree_size; node* tree_root; mutable node* current_element; mutable bool at_start_; mutable mpair p; compare comp; // restricted functions binary_search_tree_kernel_2(binary_search_tree_kernel_2&); binary_search_tree_kernel_2& operator=(binary_search_tree_kernel_2&); }; template < typename domain, typename range, typename mem_manager, typename compare > inline void swap ( binary_search_tree_kernel_2& a, binary_search_tree_kernel_2& b ) { a.swap(b); } template < typename domain, typename range, typename mem_manager, typename compare > void deserialize ( binary_search_tree_kernel_2& item, std::istream& in ) { try { item.clear(); unsigned long size; deserialize(size,in); domain d; range r; for (unsigned long i = 0; i < size; ++i) { deserialize(d,in); deserialize(r,in); item.add(d,r); } } catch (serialization_error& e) { item.clear(); throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_2"); } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > binary_search_tree_kernel_2:: ~binary_search_tree_kernel_2 ( ) { if (tree_root != NIL) 
delete_tree(tree_root); pool.deallocate(NIL); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: clear ( ) { if (tree_size > 0) { delete_tree(tree_root); tree_root = NIL; tree_size = 0; } // reset the enumerator reset(); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > size_t binary_search_tree_kernel_2:: size ( ) const { return tree_size; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > short binary_search_tree_kernel_2:: height ( ) const { return tree_height(tree_root); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > unsigned long binary_search_tree_kernel_2:: count ( const domain& item ) const { return get_count(item,tree_root); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: add ( domain& d, range& r ) { if (tree_size == 0) { tree_root = pool.allocate(); tree_root->color = black; tree_root->left = NIL; tree_root->right = NIL; tree_root->parent = NIL; exchange(tree_root->d,d); exchange(tree_root->r,r); } else { add_to_tree(tree_root,d,r); } ++tree_size; // reset the enumerator reset(); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove ( const domain& d, domain& d_copy, range& r ) { 
remove_from_tree(tree_root,d,d_copy,r); --tree_size; // reset the enumerator reset(); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: destroy ( const domain& item ) { remove_from_tree(tree_root,item); --tree_size; // reset the enumerator reset(); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove_any ( domain& d, range& r ) { remove_least_element_in_tree(tree_root,d,r); --tree_size; // reset the enumerator reset(); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > range* binary_search_tree_kernel_2:: operator[] ( const domain& d ) { return return_reference(tree_root,d); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > const range* binary_search_tree_kernel_2:: operator[] ( const domain& d ) const { return return_reference(tree_root,d); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: swap ( binary_search_tree_kernel_2& item ) { pool.swap(item.pool); exchange(p,item.p); exchange(comp,item.comp); node* tree_root_temp = item.tree_root; unsigned long tree_size_temp = item.tree_size; node* const NIL_temp = item.NIL; node* current_element_temp = item.current_element; bool at_start_temp = item.at_start_; item.tree_root = tree_root; item.tree_size = tree_size; item.NIL = NIL; item.current_element = current_element; 
item.at_start_ = at_start_; tree_root = tree_root_temp; tree_size = tree_size_temp; NIL = NIL_temp; current_element = current_element_temp; at_start_ = at_start_temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove_last_in_order ( domain& d, range& r ) { remove_biggest_element_in_tree(tree_root,d,r); --tree_size; // reset the enumerator reset(); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove_current_element ( domain& d, range& r ) { node* t = current_element; move_next(); remove_from_tree(t,t->d,d,r); --tree_size; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: position_enumerator ( const domain& d ) const { // clear the enumerator state and make sure the stack is empty reset(); at_start_ = false; node* t = tree_root; node* parent = NIL; bool went_left = false; while (t != NIL) { if ( comp(d , t->d )) { // if item is on the left then look in left parent = t; t = t->left; went_left = true; } else if (comp(t->d , d)) { // if item is on the right then look in right parent = t; t = t->right; went_left = false; } else { current_element = t; return; } } // if we didn't find any matches but there might be something after the // d in this tree. if (parent != NIL) { current_element = parent; // if we went left from this node then this node is the next // biggest. 
if (went_left) { return; } else { move_next(); } } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // enumerable function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > bool binary_search_tree_kernel_2:: at_start ( ) const { return at_start_; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: reset ( ) const { at_start_ = true; current_element = 0; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > bool binary_search_tree_kernel_2:: current_element_valid ( ) const { return (current_element != 0); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > const map_pair& binary_search_tree_kernel_2:: element ( ) const { p.d = &(current_element->d); p.r = &(current_element->r); return p; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > map_pair& binary_search_tree_kernel_2:: element ( ) { p.d = &(current_element->d); p.r = &(current_element->r); return p; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > bool binary_search_tree_kernel_2:: 
move_next ( ) const { // if we haven't started iterating yet if (at_start_) { at_start_ = false; if (tree_size == 0) { return false; } else { // find the first element in the tree current_element = tree_root; node* temp = current_element->left; while (temp != NIL) { current_element = temp; temp = current_element->left; } return true; } } else { if (current_element == 0) { return false; } else { bool went_up; // true if we went up the tree from a child node to parent bool from_left = false; // true if we went up and were coming from a left child node // find the next element in the tree if (current_element->right != NIL) { // go right and down current_element = current_element->right; went_up = false; } else { went_up = true; node* parent = current_element->parent; if (parent == NIL) { // in this case we have iterated over all the element of the tree current_element = 0; return false; } from_left = (parent->left == current_element); // go up to parent current_element = parent; } while (true) { if (went_up) { if (from_left) { // in this case we have found the next node break; } else { // we should go up node* parent = current_element->parent; from_left = (parent->left == current_element); current_element = parent; if (current_element == NIL) { // in this case we have iterated over all the elements // in the tree current_element = 0; return false; } } } else { // we just went down to a child node if (current_element->left != NIL) { // go left went_up = false; current_element = current_element->left; } else { // if there is no left child then we have found the next node break; } } } return true; } } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member function definitions // ---------------------------------------------------------------------------------------- // 
---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: delete_tree ( node* t ) { if (t->left != NIL) delete_tree(t->left); if (t->right != NIL) delete_tree(t->right); pool.deallocate(t); } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: rotate_left ( node* t ) { // perform the rotation node* temp = t->right; t->right = temp->left; if (temp->left != NIL) temp->left->parent = t; temp->left = t; temp->parent = t->parent; if (t == tree_root) tree_root = temp; else { // if t was on the left if (t->parent->left == t) t->parent->left = temp; else t->parent->right = temp; } t->parent = temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: rotate_right ( node* t ) { // perform the rotation node* temp = t->left; t->left = temp->right; if (temp->right != NIL) temp->right->parent = t; temp->right = t; temp->parent = t->parent; if (t == tree_root) tree_root = temp; else { // if t is a left child if (t->parent->left == t) t->parent->left = temp; else t->parent->right = temp; } t->parent = temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: double_rotate_right ( node* t ) { // preform the rotation node& temp = *(t->left->right); t->left = temp.right; temp.right->parent = t; temp.left->parent = temp.parent; temp.parent->right = temp.left; temp.parent->parent = &temp; temp.right = t; temp.left = temp.parent; temp.parent = t->parent; if 
(tree_root == t) tree_root = &temp; else { // t is a left child if (t->parent->left == t) t->parent->left = &temp; else t->parent->right = &temp; } t->parent = &temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: double_rotate_left ( node* t ) { // preform the rotation node& temp = *(t->right->left); t->right = temp.left; temp.left->parent = t; temp.right->parent = temp.parent; temp.parent->left = temp.right; temp.parent->parent = &temp; temp.left = t; temp.right = temp.parent; temp.parent = t->parent; if (tree_root == t) tree_root = &temp; else { // t is a left child if (t->parent->left == t) t->parent->left = &temp; else t->parent->right = &temp; } t->parent = &temp; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove_biggest_element_in_tree ( node* t, domain& d, range& r ) { node* next = t->right; node* child; // the child node of the one we will slice out if (next == NIL) { // need to determine if t is a right or left child if (t->parent->right == t) child = t->parent->right = t->left; else child = t->parent->left = t->left; // update tree_root if necessary if (t == tree_root) tree_root = child; } else { // find the least node do { t = next; next = next->right; } while (next != NIL); // t is a right child child = t->parent->right = t->left; } // swap the item from this node into d and r exchange(d,t->d); exchange(r,t->r); // plug hole right by removing this node child->parent = t->parent; // keep the red-black properties true if (t->color == black) fix_after_remove(child); // free the memory for this removed node pool.deallocate(t); } // ---------------------------------------------------------------------------------------- 
template < typename domain, typename range, typename mem_manager, typename compare > bool binary_search_tree_kernel_2:: remove_least_element_in_tree ( node* t, domain& d, range& r ) { node* next = t->left; node* child; // the child node of the one we will slice out if (next == NIL) { // need to determine if t is a left or right child if (t->parent->left == t) child = t->parent->left = t->right; else child = t->parent->right = t->right; // update tree_root if necessary if (t == tree_root) tree_root = child; } else { // find the least node do { t = next; next = next->left; } while (next != NIL); // t is a left child child = t->parent->left = t->right; } // swap the item from this node into d and r exchange(d,t->d); exchange(r,t->r); // plug hole left by removing this node child->parent = t->parent; // keep the red-black properties true if (t->color == black) fix_after_remove(child); bool rvalue = (t == current_element); // free the memory for this removed node pool.deallocate(t); return rvalue; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: add_to_tree ( node* t, domain& d, range& r ) { // parent of the current node node* parent; // find a place to add node while (true) { parent = t; // if item should be put on the left then go left if (comp(d , t->d)) { t = t->left; if (t == NIL) { t = parent->left = pool.allocate(); break; } } // if item should be put on the right then go right else { t = t->right; if (t == NIL) { t = parent->right = pool.allocate(); break; } } } // t is now the node where we will add item and // parent is the parent of t t->parent = parent; t->left = NIL; t->right = NIL; t->color = red; exchange(t->d,d); exchange(t->r,r); // keep the red-black properties true fix_after_add(t); } // ---------------------------------------------------------------------------------------- template < 
typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove_from_tree ( node* t, const domain& d, domain& d_copy, range& r ) { while (true) { if ( comp(d , t->d) ) { // if item is on the left then look in left t = t->left; } else if (comp(t->d , d)) { // if item is on the right then look in right t = t->right; } else { // found the node we want to remove // swap out the item into d_copy and r exchange(d_copy,t->d); exchange(r,t->r); if (t->left == NIL) { // if there is no left subtree node* parent = t->parent; // plug hole with right subtree // if t is on the left if (parent->left == t) parent->left = t->right; else parent->right = t->right; t->right->parent = parent; // update tree_root if necessary if (t == tree_root) tree_root = t->right; if (t->color == black) fix_after_remove(t->right); // delete old node pool.deallocate(t); } else if (t->right == NIL) { // if there is no right subtree node* parent = t->parent; // plug hole with left subtree if (parent->left == t) parent->left = t->left; else parent->right = t->left; t->left->parent = parent; // update tree_root if necessary if (t == tree_root) tree_root = t->left; if (t->color == black) fix_after_remove(t->left); // delete old node pool.deallocate(t); } else { // if there is both a left and right subtree // get an element to fill this node now that its been swapped into // item_copy if (remove_least_element_in_tree(t->right,t->d,t->r)) { // the node removed was the one pointed to by current_element so we // need to update it so that it points to the right spot. 
current_element = t; } } // quit loop break; } } } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: remove_from_tree ( node* t, const domain& d ) { while (true) { if ( comp(d , t->d) ) { // if item is on the left then look in left t = t->left; } else if (comp(t->d , d)) { // if item is on the right then look in right t = t->right; } else { // found the node we want to remove if (t->left == NIL) { // if there is no left subtree node* parent = t->parent; // plug hole with right subtree if (parent->left == t) parent->left = t->right; else parent->right = t->right; t->right->parent = parent; // update tree_root if necessary if (t == tree_root) tree_root = t->right; if (t->color == black) fix_after_remove(t->right); // delete old node pool.deallocate(t); } else if (t->right == NIL) { // if there is no right subtree node* parent = t->parent; // plug hole with left subtree if (parent->left == t) parent->left = t->left; else parent->right = t->left; t->left->parent = parent; // update tree_root if necessary if (t == tree_root) tree_root = t->left; if (t->color == black) fix_after_remove(t->left); // delete old node pool.deallocate(t); } else { // if there is both a left and right subtree // get an element to fill this node now that its been swapped into // item_copy remove_least_element_in_tree(t->right,t->d,t->r); } // quit loop break; } } } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > range* binary_search_tree_kernel_2:: return_reference ( node* t, const domain& d ) { while (t != NIL) { if ( comp(d , t->d )) { // if item is on the left then look in left t = t->left; } else if (comp(t->d , d)) { // if item is on the right then look in right t = t->right; } else { // if it's found 
then return a reference to it return &(t->r); } } return 0; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > const range* binary_search_tree_kernel_2:: return_reference ( const node* t, const domain& d ) const { while (t != NIL) { if ( comp(d , t->d) ) { // if item is on the left then look in left t = t->left; } else if (comp(t->d , d)) { // if item is on the right then look in right t = t->right; } else { // if it's found then return a reference to it return &(t->r); } } return 0; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > void binary_search_tree_kernel_2:: fix_after_add ( node* t ) { while (t->parent->color == red) { node& grandparent = *(t->parent->parent); // if both t's parent and its sibling are red if (grandparent.left->color == grandparent.right->color) { grandparent.color = red; grandparent.left->color = black; grandparent.right->color = black; t = &grandparent; } else { // if t is a left child if (t == t->parent->left) { // if t's parent is a left child if (t->parent == grandparent.left) { grandparent.color = red; grandparent.left->color = black; rotate_right(&grandparent); } // if t's parent is a right child else { t->color = black; grandparent.color = red; double_rotate_left(&grandparent); } } // if t is a right child else { // if t's parent is a left child if (t->parent == grandparent.left) { t->color = black; grandparent.color = red; double_rotate_right(&grandparent); } // if t's parent is a right child else { grandparent.color = red; grandparent.right->color = black; rotate_left(&grandparent); } } break; } } tree_root->color = black; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename 
mem_manager, typename compare > void binary_search_tree_kernel_2:: fix_after_remove ( node* t ) { while (t != tree_root && t->color == black) { if (t->parent->left == t) { node* sibling = t->parent->right; if (sibling->color == red) { sibling->color = black; t->parent->color = red; rotate_left(t->parent); sibling = t->parent->right; } if (sibling->left->color == black && sibling->right->color == black) { sibling->color = red; t = t->parent; } else { if (sibling->right->color == black) { sibling->left->color = black; sibling->color = red; rotate_right(sibling); sibling = t->parent->right; } sibling->color = t->parent->color; t->parent->color = black; sibling->right->color = black; rotate_left(t->parent); t = tree_root; } } else { node* sibling = t->parent->left; if (sibling->color == red) { sibling->color = black; t->parent->color = red; rotate_right(t->parent); sibling = t->parent->left; } if (sibling->left->color == black && sibling->right->color == black) { sibling->color = red; t = t->parent; } else { if (sibling->left->color == black) { sibling->right->color = black; sibling->color = red; rotate_left(sibling); sibling = t->parent->left; } sibling->color = t->parent->color; t->parent->color = black; sibling->left->color = black; rotate_right(t->parent); t = tree_root; } } } t->color = black; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > short binary_search_tree_kernel_2:: tree_height ( node* t ) const { if (t == NIL) return 0; short height1 = tree_height(t->left); short height2 = tree_height(t->right); if (height1 > height2) return height1 + 1; else return height2 + 1; } // ---------------------------------------------------------------------------------------- template < typename domain, typename range, typename mem_manager, typename compare > unsigned long binary_search_tree_kernel_2:: get_count ( const domain& d, node* tree_root 
) const { if (tree_root != NIL) { if (comp(d , tree_root->d)) { // go left return get_count(d,tree_root->left); } else if (comp(tree_root->d , d)) { // go right return get_count(d,tree_root->right); } else { // go left and right to look for more matches return get_count(d,tree_root->left) + get_count(d,tree_root->right) + 1; } } return 0; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BINARY_SEARCH_TREE_KERNEl_2_ ================================================ FILE: dlib/binary_search_tree/binary_search_tree_kernel_abstract.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ #ifdef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ #include "../interfaces/map_pair.h" #include "../interfaces/enumerable.h" #include "../interfaces/remover.h" #include "../serialize.h" #include "../algs.h" #include namespace dlib { template < typename domain, typename range, typename mem_manager = default_memory_manager, typename compare = std::less > class binary_search_tree : public enumerable >, public asc_pair_remover { /*! REQUIREMENTS ON domain domain must be comparable by compare where compare is a functor compatible with std::less and domain is swappable by a global swap() and domain must have a default constructor REQUIREMENTS ON range range is swappable by a global swap() and range must have a default constructor REQUIREMENTS ON mem_manager must be an implementation of memory_manager/memory_manager_kernel_abstract.h or must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h mem_manager::type can be set to anything. 
POINTERS AND REFERENCES TO INTERNAL DATA swap(), count(), height(), and operator[] functions do not invalidate pointers or references to internal data. position_enumerator() invalidates pointers or references to data returned by element() and only by element() (i.e. pointers and references returned by operator[] are still valid). All other functions have no such guarantees. INITIAL VALUE size() == 0 height() == 0 ENUMERATION ORDER The enumerator will iterate over the domain (and each associated range element) elements in ascending order according to the compare functor. (i.e. the elements are enumerated in sorted order) WHAT THIS OBJECT REPRESENTS this object represents a data dictionary that is built on top of some kind of binary search tree. It maps objects of type domain to objects of type range. Also note that unless specified otherwise, no member functions of this object throw exceptions. NOTE: definition of equivalent: a is equivalent to b if a < b == false and b < a == false !*/ public: typedef domain domain_type; typedef range range_type; typedef compare compare_type; typedef mem_manager mem_manager_type; binary_search_tree( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc or any exception thrown by domain's or range's constructor. !*/ virtual ~binary_search_tree( ); /*! ensures - all memory associated with *this has been released !*/ void clear( ); /*! ensures - #*this has its initial value throws - std::bad_alloc or any exception thrown by domain's or range's constructor. if this exception is thrown then *this is unusable until clear() is called and succeeds !*/ short height ( ) const; /*! ensures - returns the number of elements in the longest path from the root of the tree to a leaf !*/ unsigned long count ( const domain& d ) const; /*! ensures - returns the number of elements in the domain of *this that are equivalent to d !*/ void add ( domain& d, range& r ); /*! requires - &d != &r (i.e. 
d and r cannot be the same variable) ensures - adds a mapping between d and r to *this - if (count(d) == 0) then - #*(*this)[d] == r - else - #(*this)[d] != 0 - #d and #r have initial values for their types - #count(d) == count(d) + 1 - #at_start() == true - #size() == size() + 1 throws - std::bad_alloc or any exception thrown by domain's or range's constructor. if add() throws then it has no effect !*/ void remove ( const domain& d, domain& d_copy, range& r ); /*! requires - (*this)[d] != 0 - &d != &r (i.e. d and r cannot be the same variable) - &d != &d_copy (i.e. d and d_copy cannot be the same variable) - &r != &d_copy (i.e. r and d_copy cannot be the same variable) ensures - some element in the domain of *this that is equivalent to d has been removed and swapped into #d_copy. Additionally, its associated range element has been removed and swapped into #r. - #count(d) == count(d) - 1 - #size() == size() - 1 - #at_start() == true !*/ void destroy ( const domain& d ); /*! requires - (*this)[d] != 0 ensures - an element in the domain of *this equivalent to d has been removed. The element in the range of *this associated with d has also been removed. - #count(d) == count(d) - 1 - #size() == size() - 1 - #at_start() == true !*/ void remove_last_in_order ( domain& d, range& r ); /*! requires - size() > 0 ensures - the last/biggest (according to the compare functor) element in the domain of *this has been removed and swapped into #d. The element in the range of *this associated with #d has also been removed and swapped into #r. - #count(#d) == count(#d) - 1 - #size() == size() - 1 - #at_start() == true !*/ void remove_current_element ( domain& d, range& r ); /*! requires - current_element_valid() == true ensures - the current element given by element() has been removed and swapped into d and r. - #d == element().key() - #r == element().value() - #count(#d) == count(#d) - 1 - #size() == size() - 1 - moves the enumerator to the next element. 
If element() was the last element in enumeration order then #current_element_valid() == false and #at_start() == false. !*/ void position_enumerator ( const domain& d ) const; /*! ensures - #at_start() == false - if (count(d) > 0) then - #element().key() == d - else if (there are any items in the domain of *this that are bigger than d according to the compare functor) then - #element().key() == the smallest item in the domain of *this that is bigger than d according to the compare functor. - else - #current_element_valid() == false !*/ const range* operator[] ( const domain& d ) const; /*! ensures - if (there is an element in the domain equivalent to d) then - returns a pointer to an element in the range of *this that is associated with an element in the domain of *this equivalent to d. - else - returns 0 !*/ range* operator[] ( const domain& d ); /*! ensures - if (there is an element in the domain equivalent to d) then - returns a pointer to an element in the range of *this that is associated with an element in the domain of *this equivalent to d. - else - returns 0 !*/ void swap ( binary_search_tree& item ); /*! ensures - swaps *this and item !*/ private: // restricted functions binary_search_tree(binary_search_tree&); binary_search_tree& operator=(binary_search_tree&); }; template < typename domain, typename range, typename mem_manager, typename compare > inline void swap ( binary_search_tree& a, binary_search_tree& b ) { a.swap(b); } /*! provides a global swap function !*/ template < typename domain, typename range, typename mem_manager, typename compare > void deserialize ( binary_search_tree& item, std::istream& in ); /*! provides deserialization support !*/ } #endif // DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ ================================================ FILE: dlib/binary_search_tree/binary_search_tree_kernel_c.h ================================================ // Copyright (C) 2003 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_C_ #define DLIB_BINARY_SEARCH_TREE_KERNEl_C_ #include "../interfaces/map_pair.h" #include "binary_search_tree_kernel_abstract.h" #include "../algs.h" #include "../assert.h" namespace dlib { template < typename bst_base > class binary_search_tree_kernel_c : public bst_base { typedef typename bst_base::domain_type domain; typedef typename bst_base::range_type range; public: binary_search_tree_kernel_c () {} void remove ( const domain& d, domain& d_copy, range& r ); void destroy ( const domain& d ); void add ( domain& d, range& r ); void remove_any ( domain& d, range& r ); const map_pair& element( ) const { DLIB_CASSERT(this->current_element_valid() == true, "\tconst map_pair& binary_search_tree::element() const" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); return bst_base::element(); } map_pair& element( ) { DLIB_CASSERT(this->current_element_valid() == true, "\tmap_pair& binary_search_tree::element()" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); return bst_base::element(); } void remove_last_in_order ( domain& d, range& r ); void remove_current_element ( domain& d, range& r ); }; template < typename bst_base > inline void swap ( binary_search_tree_kernel_c& a, binary_search_tree_kernel_c& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename bst_base > void binary_search_tree_kernel_c:: add ( domain& d, range& r ) { DLIB_CASSERT( static_cast(&d) != 
static_cast(&r), "\tvoid binary_search_tree::add" << "\n\tyou can't call add() and give the same object to both parameters." << "\n\tthis: " << this << "\n\t&d: " << &d << "\n\t&r: " << &r << "\n\tsize(): " << this->size() ); bst_base::add(d,r); } // ---------------------------------------------------------------------------------------- template < typename bst_base > void binary_search_tree_kernel_c:: destroy ( const domain& d ) { DLIB_CASSERT(this->operator[](d) != 0, "\tvoid binary_search_tree::destroy" << "\n\tthe element must be in the tree for it to be removed" << "\n\tthis: " << this << "\n\t&d: " << &d ); bst_base::destroy(d); } // ---------------------------------------------------------------------------------------- template < typename bst_base > void binary_search_tree_kernel_c:: remove ( const domain& d, domain& d_copy, range& r ) { DLIB_CASSERT(this->operator[](d) != 0 && (static_cast(&d) != static_cast(&d_copy)) && (static_cast(&d) != static_cast(&r)) && (static_cast(&r) != static_cast(&d_copy)), "\tvoid binary_search_tree::remove" << "\n\tthe element must be in the tree for it to be removed" << "\n\tthis: " << this << "\n\t&d: " << &d << "\n\t&d_copy: " << &d_copy << "\n\t&r: " << &r ); bst_base::remove(d,d_copy,r); } // ---------------------------------------------------------------------------------------- template < typename bst_base > void binary_search_tree_kernel_c:: remove_any( domain& d, range& r ) { DLIB_CASSERT(this->size() != 0 && (static_cast(&d) != static_cast(&r)), "\tvoid binary_search_tree::remove_any" << "\n\ttree must not be empty if something is going to be removed" << "\n\tthis: " << this << "\n\t&d: " << &d << "\n\t&r: " << &r ); bst_base::remove_any(d,r); } // ---------------------------------------------------------------------------------------- template < typename bst_base > void binary_search_tree_kernel_c:: remove_last_in_order ( domain& d, range& r ) { DLIB_CASSERT(this->size() > 0, "\tvoid 
binary_search_tree::remove_last_in_order()" << "\n\tyou can't remove an element if it doesn't exist" << "\n\tthis: " << this ); bst_base::remove_last_in_order(d,r); } // ---------------------------------------------------------------------------------------- template < typename bst_base > void binary_search_tree_kernel_c:: remove_current_element ( domain& d, range& r ) { DLIB_CASSERT(this->current_element_valid() == true, "\tvoid binary_search_tree::remove_current_element()" << "\n\tyou can't remove the current element if it doesn't exist" << "\n\tthis: " << this ); bst_base::remove_current_element(d,r); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BINARY_SEARCH_TREE_KERNEl_C_ ================================================ FILE: dlib/binary_search_tree.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BINARY_SEARCH_TREe_ #define DLIB_BINARY_SEARCH_TREe_ #include "binary_search_tree/binary_search_tree_kernel_1.h" #include "binary_search_tree/binary_search_tree_kernel_2.h" #include "binary_search_tree/binary_search_tree_kernel_c.h" #include "algs.h" #include namespace dlib { template < typename domain, typename range, typename mem_manager = default_memory_manager, typename compare = std::less > class binary_search_tree { binary_search_tree() {} public: //----------- kernels --------------- // kernel_1a typedef binary_search_tree_kernel_1 kernel_1a; typedef binary_search_tree_kernel_c kernel_1a_c; // kernel_2a typedef binary_search_tree_kernel_2 kernel_2a; typedef binary_search_tree_kernel_c kernel_2a_c; }; } #endif // DLIB_BINARY_SEARCH_TREe_ ================================================ FILE: dlib/bit_stream/bit_stream_kernel_1.cpp ================================================ // Copyright (C) 2003 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIT_STREAM_KERNEL_1_CPp_ #define DLIB_BIT_STREAM_KERNEL_1_CPp_ #include "bit_stream_kernel_1.h" #include "../algs.h" #include namespace dlib { inline void swap ( bit_stream_kernel_1& a, bit_stream_kernel_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void bit_stream_kernel_1:: clear ( ) { if (write_mode) { write_mode = false; // flush output buffer if (buffer_size > 0) { buffer <<= 8 - buffer_size; osp->write(reinterpret_cast(&buffer),1); } } else read_mode = false; } // ---------------------------------------------------------------------------------------- void bit_stream_kernel_1:: set_input_stream ( std::istream& is ) { isp = &is; read_mode = true; buffer_size = 0; } // ---------------------------------------------------------------------------------------- void bit_stream_kernel_1:: set_output_stream ( std::ostream& os ) { osp = &os; write_mode = true; buffer_size = 0; } // ---------------------------------------------------------------------------------------- void bit_stream_kernel_1:: close ( ) { if (write_mode) { write_mode = false; // flush output buffer if (buffer_size > 0) { buffer <<= 8 - buffer_size; osp->write(reinterpret_cast(&buffer),1); } } else read_mode = false; } // ---------------------------------------------------------------------------------------- bool bit_stream_kernel_1:: is_in_write_mode ( ) const { return write_mode; } // ---------------------------------------------------------------------------------------- bool bit_stream_kernel_1:: is_in_read_mode ( ) const 
{ return read_mode; } // ---------------------------------------------------------------------------------------- void bit_stream_kernel_1:: write ( int bit ) { // flush buffer if necessary if (buffer_size == 8) { buffer <<= 8 - buffer_size; if (osp->rdbuf()->sputn(reinterpret_cast(&buffer),1) == 0) { throw std::ios_base::failure("error occurred in the bit_stream object"); } buffer_size = 0; } ++buffer_size; buffer <<= 1; buffer += static_cast(bit); } // ---------------------------------------------------------------------------------------- bool bit_stream_kernel_1:: read ( int& bit ) { // get new byte if necessary if (buffer_size == 0) { if (isp->rdbuf()->sgetn(reinterpret_cast(&buffer), 1) == 0) { // if we didn't read anything then return false return false; } buffer_size = 8; } // put the most significant bit from buffer into bit bit = static_cast(buffer >> 7); // shift out the bit that was just read buffer <<= 1; --buffer_size; return true; } // ---------------------------------------------------------------------------------------- void bit_stream_kernel_1:: swap ( bit_stream_kernel_1& item ) { std::istream* isp_temp = item.isp; std::ostream* osp_temp = item.osp; bool write_mode_temp = item.write_mode; bool read_mode_temp = item.read_mode; unsigned char buffer_temp = item.buffer; unsigned short buffer_size_temp = item.buffer_size; item.isp = isp; item.osp = osp; item.write_mode = write_mode; item.read_mode = read_mode; item.buffer = buffer; item.buffer_size = buffer_size; isp = isp_temp; osp = osp_temp; write_mode = write_mode_temp; read_mode = read_mode_temp; buffer = buffer_temp; buffer_size = buffer_size_temp; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIT_STREAM_KERNEL_1_CPp_ ================================================ FILE: dlib/bit_stream/bit_stream_kernel_1.h ================================================ // Copyright (C) 2003 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIT_STREAM_KERNEl_1_ #define DLIB_BIT_STREAM_KERNEl_1_ #include "bit_stream_kernel_abstract.h" #include namespace dlib { class bit_stream_kernel_1 { /*! INITIAL VALUE write_mode == false read_mode == false CONVENTION write_mode == is_in_write_mode() read_mode == is_in_read_mode() if (write_mode) { osp == pointer to an ostream object buffer == the low order bits of buffer are the bits to be written buffer_size == the number of low order bits in buffer that are bits that should be written the lowest order bit is the last bit entered by the user } if (read_mode) { isp == pointer to an istream object buffer == the high order bits of buffer are the bits waiting to be read by the user buffer_size == the number of high order bits in buffer that are bits that are waiting to be read the highest order bit is the next bit to give to the user } !*/ public: bit_stream_kernel_1 ( ) : write_mode(false), read_mode(false) {} virtual ~bit_stream_kernel_1 ( ) {} void clear ( ); void set_input_stream ( std::istream& is ); void set_output_stream ( std::ostream& os ); void close ( ); inline bool is_in_write_mode ( ) const; inline bool is_in_read_mode ( ) const; inline void write ( int bit ); bool read ( int& bit ); void swap ( bit_stream_kernel_1& item ); private: // member data std::istream* isp; std::ostream* osp; bool write_mode; bool read_mode; unsigned char buffer; unsigned short buffer_size; // restricted functions bit_stream_kernel_1(bit_stream_kernel_1&); // copy constructor bit_stream_kernel_1& operator=(bit_stream_kernel_1&); // assignment operator }; inline void swap ( bit_stream_kernel_1& a, bit_stream_kernel_1& b ); // ---------------------------------------------------------------------------------------- } #ifdef NO_MAKEFILE #include "bit_stream_kernel_1.cpp" #endif #endif // DLIB_BIT_STREAM_KERNEl_1_ ================================================ FILE: 
dlib/bit_stream/bit_stream_kernel_abstract.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ #ifdef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ #include namespace dlib { class bit_stream { /*! INITIAL VALUE is_in_write_mode() == false is_in_read_mode() == false WHAT THIS OBJECT REPRESENTS this object is a middle man between a user and the iostream classes. it allows single bits to be read/written easily to/from the iostream classes BUFFERING: This object will only read/write single bytes at a time from/to the iostream objects. Any buffered bits still in the bit_stream object when it is closed or destructed are lost if it is in read mode. If it is in write mode then any remaining bits are guaranteed to be written to the output stream by the time it is closed or destructed. !*/ public: bit_stream ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc !*/ virtual ~bit_stream ( ); /*! ensures - all memory associated with *this has been released !*/ void clear ( ); /*! ensures - #*this has its initial value throws - std::bad_alloc if this exception is thrown then *this is unusable until clear() is called and succeeds !*/ void set_input_stream ( std::istream& is ); /*! requires - is_in_write_mode() == false - is_in_read_mode() == false - is is ready to give input ensures - #is_in_write_mode() == false - #is_in_read_mode() == true - #*this will now be reading from is throws - std::bad_alloc !*/ void set_output_stream ( std::ostream& os ); /*! requires - is_in_write_mode() == false - is_in_read_mode() == false - os is ready to take output ensures - #is_in_write_mode() == true - #is_in_read_mode() == false - #*this will now write to os throws - std::bad_alloc !*/ void close ( ); /*! 
requires - is_in_write_mode() == true || is_in_read_mode() == true ensures - #is_in_write_mode() == false - #is_in_read_mode() == false !*/ bool is_in_write_mode ( ) const; /*! ensures - returns true if *this is associated with an output stream object - returns false otherwise !*/ bool is_in_read_mode ( ) const; /*! ensures - returns true if *this is associated with an input stream object - returns false otherwise !*/ void write ( int bit ); /*! requires - is_in_write_mode() == true - bit == 0 || bit == 1 ensures - bit will be written to the ostream object associated with *this throws - std::ios_base::failure if (there was a problem writing to the output stream) then this exception will be thrown. #*this will be unusable until clear() is called and succeeds - any other exception if this exception is thrown then #*this is unusable until clear() is called and succeeds !*/ bool read ( int& bit ); /*! requires - is_in_read_mode() == true ensures - the next bit has been read and placed into #bit - returns true if the read was successful, else false (ex. false if EOF has been reached) throws - any exception if this exception is thrown then #*this is unusable until clear() is called and succeeds !*/ void swap ( bit_stream& item ); /*! ensures - swaps *this and item !*/ private: // restricted functions bit_stream(bit_stream&); // copy constructor bit_stream& operator=(bit_stream&); // assignment operator }; inline void swap ( bit_stream& a, bit_stream& b ) { a.swap(b); } /*! provides a global swap function !*/ } #endif // DLIB_BIT_STREAM_KERNEl_ABSTRACT_ ================================================ FILE: dlib/bit_stream/bit_stream_kernel_c.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BIT_STREAM_KERNEl_C_ #define DLIB_BIT_STREAM_KERNEl_C_ #include "bit_stream_kernel_abstract.h" #include "../algs.h" #include "../assert.h" #include namespace dlib { template < typename bit_stream_base // implements bit_stream/bit_stream_kernel_abstract.h > class bit_stream_kernel_c : public bit_stream_base { public: void set_input_stream ( std::istream& is ); void set_output_stream ( std::ostream& os ); void close ( ); void write ( int bit ); bool read ( int& bit ); }; template < typename bit_stream_base > inline void swap ( bit_stream_kernel_c& a, bit_stream_kernel_c& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > void bit_stream_kernel_c:: set_input_stream ( std::istream& is ) { // make sure requires clause is not broken DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), "\tvoid bit_stream::set_intput_stream" << "\n\tbit_stream must not be in write or read mode" << "\n\tthis: " << this ); // call the real function bit_stream_base::set_input_stream(is); } // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > void bit_stream_kernel_c:: set_output_stream ( std::ostream& os ) { // make sure requires clause is not broken DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), "\tvoid bit_stream::set_output_stream" << "\n\tbit_stream must not be in write or read mode" << "\n\tthis: " << this ); // call the real function bit_stream_base::set_output_stream(os); } // 
---------------------------------------------------------------------------------------- template < typename bit_stream_base > void bit_stream_kernel_c:: close ( ) { // make sure requires clause is not broken DLIB_CASSERT(( this->is_in_write_mode() == true ) || ( this->is_in_read_mode() == true ), "\tvoid bit_stream::close" << "\n\tyou can't close a bit_stream that isn't open" << "\n\tthis: " << this ); // call the real function bit_stream_base::close(); } // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > void bit_stream_kernel_c:: write ( int bit ) { // make sure requires clause is not broken DLIB_CASSERT(( this->is_in_write_mode() == true ) && ( bit == 0 || bit == 1 ), "\tvoid bit_stream::write" << "\n\tthe bit stream bust be in write mode and bit must be either 1 or 0" << "\n\tis_in_write_mode() == " << this->is_in_write_mode() << "\n\tbit == " << bit << "\n\tthis: " << this ); // call the real function bit_stream_base::write(bit); } // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > bool bit_stream_kernel_c:: read ( int& bit ) { // make sure requires clause is not broken DLIB_CASSERT(( this->is_in_read_mode() == true ), "\tbool bit_stream::read" << "\n\tyou can't read from a bit_stream that isn't in read mode" << "\n\tthis: " << this ); // call the real function return bit_stream_base::read(bit); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIT_STREAM_KERNEl_C_ ================================================ FILE: dlib/bit_stream/bit_stream_multi_1.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BIT_STREAM_MULTi_1_ #define DLIB_BIT_STREAM_MULTi_1_ #include "bit_stream_multi_abstract.h" namespace dlib { template < typename bit_stream_base > class bit_stream_multi_1 : public bit_stream_base { public: void multi_write ( unsigned long data, int num_to_write ); int multi_read ( unsigned long& data, int num_to_read ); }; template < typename bit_stream_base > inline void swap ( bit_stream_multi_1& a, bit_stream_multi_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > void bit_stream_multi_1:: multi_write ( unsigned long data, int num_to_write ) { // move the first bit into the most significant position data <<= 32 - num_to_write; for (int i = 0; i < num_to_write; ++i) { // write the first bit from data this->write(static_cast(data >> 31)); // shift the next bit into position data <<= 1; } } // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > int bit_stream_multi_1:: multi_read ( unsigned long& data, int num_to_read ) { int bit, i; data = 0; for (i = 0; i < num_to_read; ++i) { // get a bit if (this->read(bit) == false) break; // shift data to make room for this new bit data <<= 1; // put bit into the least significant position in data data += static_cast(bit); } return i; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIT_STREAM_MULTi_1_ ================================================ FILE: dlib/bit_stream/bit_stream_multi_abstract.h ================================================ // Copyright (C) 2003 Davis 
E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
// NOTE: DLIB_BIT_STREAM_MULTi_ABSTRACT_ is #undef'd immediately before the #ifdef
// below, so nothing in this file is ever compiled.  It only documents the interface
// of the multi-bit bit_stream extension.
#undef DLIB_BIT_STREAM_MULTi_ABSTRACT_
#ifdef DLIB_BIT_STREAM_MULTi_ABSTRACT_

#include "bit_stream_kernel_abstract.h"

namespace dlib
{
    template <
        typename bit_stream_base
        >
    class bit_stream_multi : public bit_stream_base
    {

        /*!
            REQUIREMENTS ON BIT_STREAM_BASE
                it is an implementation of bit_stream/bit_stream_kernel_abstract.h


            WHAT THIS EXTENSION DOES FOR BIT_STREAM
                this gives a bit_stream object the ability to read/write multiple bits
                at a time
        !*/


    public:

        void multi_write (
            unsigned long data,
            int num_to_write
        );
        /*!
            requires
                - is_in_write_mode() == true
                - 0 <= num_to_write <= 32
            ensures
                - num_to_write low order bits from data will be written to the ostream
                  object associated with *this, most significant of those bits first
                  example: if data is 10010 then the bits will be written in the
                  order 1,0,0,1,0
        !*/


        int multi_read (
            unsigned long& data,
            int num_to_read
        );
        /*!
            requires
                - is_in_read_mode() == true
                - 0 <= num_to_read <= 32
            ensures
                - tries to read num_to_read bits into the low order end of #data
                  example: if the incoming bits were 10010 then data would end
                  up with 10010 as its low order bits
                - all of the bits in #data not filled in by multi_read() are zero
                - returns the number of bits actually read into #data
                  (fewer than num_to_read if the underlying read() fails, e.g. at EOF)
        !*/

    };

    template <
        typename bit_stream_base
        >
    inline void swap (
        bit_stream_multi& a,
        bit_stream_multi& b
    ) { a.swap(b); }
    /*!
        provides a global swap function
    !*/

}

#endif // DLIB_BIT_STREAM_MULTi_ABSTRACT_

================================================
FILE: dlib/bit_stream/bit_stream_multi_c.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BIT_STREAM_MULTi_C_ #define DLIB_BIT_STREAM_MULTi_C_ #include "bit_stream_multi_abstract.h" #include "../algs.h" #include "../assert.h" namespace dlib { template < typename bit_stream_base // implements bit_stream/bit_stream_multi_abstract.h > class bit_stream_multi_c : public bit_stream_base { public: void multi_write ( unsigned long data, int num_to_write ); int multi_read ( unsigned long& data, int num_to_read ); }; template < typename bit_stream_base > inline void swap ( bit_stream_multi_c& a, bit_stream_multi_c& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > void bit_stream_multi_c:: multi_write ( unsigned long data, int num_to_write ) { // make sure requires clause is not broken DLIB_CASSERT( (this->is_in_write_mode() == true) && (num_to_write >= 0 && num_to_write <=32), "\tvoid bit_stream::write" << "\n\tthe bit stream bust be in write mode and" << "\n\tnum_to_write must be between 0 and 32 inclusive" << "\n\tnum_to_write == " << num_to_write << "\n\tis_in_write_mode() == " << this->is_in_write_mode() << "\n\tthis: " << this ); // call the real function bit_stream_base::multi_write(data,num_to_write); } // ---------------------------------------------------------------------------------------- template < typename bit_stream_base > int bit_stream_multi_c:: multi_read ( unsigned long& data, int num_to_read ) { // make sure requires clause is not broken DLIB_CASSERT(( this->is_in_read_mode() == true && ( num_to_read >= 0 && num_to_read <=32 ) ), "\tvoid bit_stream::read" << "\n\tyou can't read from a bit_stream that isn't in read mode and" << 
"\n\tnum_to_read must be between 0 and 32 inclusive" << "\n\tnum_to_read == " << num_to_read << "\n\tis_in_read_mode() == " << this->is_in_read_mode() << "\n\tthis: " << this ); // call the real function return bit_stream_base::multi_read(data,num_to_read); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BIT_STREAM_MULTi_C_ ================================================ FILE: dlib/bit_stream.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BIT_STREAm_ #define DLIB_BIT_STREAm_ #include "bit_stream/bit_stream_kernel_1.h" #include "bit_stream/bit_stream_kernel_c.h" #include "bit_stream/bit_stream_multi_1.h" #include "bit_stream/bit_stream_multi_c.h" namespace dlib { class bit_stream { bit_stream() {} public: //----------- kernels --------------- // kernel_1a typedef bit_stream_kernel_1 kernel_1a; typedef bit_stream_kernel_c kernel_1a_c; //---------- extensions ------------ // multi_1 extend kernel_1a typedef bit_stream_multi_1 multi_1a; typedef bit_stream_multi_c > multi_1a_c; }; } #endif // DLIB_BIT_STREAm_ ================================================ FILE: dlib/bits/c++config.h ================================================ #include "../dlib_include_path_tutorial.txt" ================================================ FILE: dlib/bound_function_pointer/bound_function_pointer_kernel_1.h ================================================ // Copyright (C) 2008 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ #define DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ #include "../algs.h" #include "../member_function_pointer.h" #include "bound_function_pointer_kernel_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace bfp1_helpers { template struct strip { typedef T type; }; template struct strip { typedef T type; }; // ------------------------------------------------------------------------------------ class bound_function_helper_base_base { public: virtual ~bound_function_helper_base_base(){} virtual void call() const = 0; virtual bool is_set() const = 0; virtual void clone(void* ptr) const = 0; }; // ------------------------------------------------------------------------------------ template class bound_function_helper_base : public bound_function_helper_base_base { public: bound_function_helper_base():arg1(0), arg2(0), arg3(0), arg4(0) {} typename strip::type* arg1; typename strip::type* arg2; typename strip::type* arg3; typename strip::type* arg4; member_function_pointer mfp; }; // ---------------- template class bound_function_helper : public bound_function_helper_base { public: void call() const { (*fp)(*this->arg1, *this->arg2, *this->arg3, *this->arg4); } typename strip::type* fp; }; template class bound_function_helper : public bound_function_helper_base { public: void call() const { if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); else if (fp) fp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); } void (*fp)(T1, T2, T3, T4); }; // ---------------- template class bound_function_helper : public bound_function_helper_base { public: void call() const { (*fp)(); } typename strip::type* fp; }; template <> class bound_function_helper : public bound_function_helper_base { public: void call() const { if (this->mfp) this->mfp(); else if (fp) fp(); } void (*fp)(); }; // ---------------- template class bound_function_helper : 
public bound_function_helper_base { public: void call() const { (*fp)(*this->arg1); } typename strip::type* fp; }; template class bound_function_helper : public bound_function_helper_base { public: void call() const { if (this->mfp) this->mfp(*this->arg1); else if (fp) fp(*this->arg1); } void (*fp)(T1); }; // ---------------- template class bound_function_helper : public bound_function_helper_base { public: void call() const { (*fp)(*this->arg1, *this->arg2); } typename strip::type* fp; }; template class bound_function_helper : public bound_function_helper_base { public: void call() const { if (this->mfp) this->mfp(*this->arg1, *this->arg2); else if (fp) fp(*this->arg1, *this->arg2); } void (*fp)(T1, T2); }; // ---------------- template class bound_function_helper : public bound_function_helper_base { public: void call() const { (*fp)(*this->arg1, *this->arg2, *this->arg3); } typename strip::type* fp; }; template class bound_function_helper : public bound_function_helper_base { public: void call() const { if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3); else if (fp) fp(*this->arg1, *this->arg2, *this->arg3); } void (*fp)(T1, T2, T3); }; // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ template class bound_function_helper_T : public T { public: bound_function_helper_T(){ this->fp = 0;} bool is_set() const { return this->fp != 0 || this->mfp.is_set(); } template void safe_clone(stack_based_memory_block& buf) { // This is here just to validate the assumption that our block of memory we have made // in bf_memory is the right size to store the data for this object. 
If you // get a compiler error on this line then email me :) COMPILE_TIME_ASSERT(sizeof(bound_function_helper_T) <= mem_size); clone(buf.get()); } void clone (void* ptr) const { bound_function_helper_T* p = new(ptr) bound_function_helper_T(); p->arg1 = this->arg1; p->arg2 = this->arg2; p->arg3 = this->arg3; p->arg4 = this->arg4; p->fp = this->fp; p->mfp = this->mfp; } }; } // ---------------------------------------------------------------------------------------- class bound_function_pointer { typedef bfp1_helpers::bound_function_helper_T > bf_null_type; public: // These typedefs are here for backwards compatibility with previous versions of // dlib. typedef bound_function_pointer kernel_1a; typedef bound_function_pointer kernel_1a_c; bound_function_pointer ( ) { bf_null_type().safe_clone(bf_memory); } bound_function_pointer ( const bound_function_pointer& item ) { item.bf()->clone(bf_memory.get()); } ~bound_function_pointer() { destroy_bf_memory(); } bound_function_pointer& operator= ( const bound_function_pointer& item ) { bound_function_pointer(item).swap(*this); return *this; } void clear ( ) { bound_function_pointer().swap(*this); } bool is_set ( ) const { return bf()->is_set(); } void swap ( bound_function_pointer& item ) { // make a temp copy of item bound_function_pointer temp(item); // destory the stuff in item item.destroy_bf_memory(); // copy *this into item bf()->clone(item.bf_memory.get()); // destory the stuff in this destroy_bf_memory(); // copy temp into *this temp.bf()->clone(bf_memory.get()); } void operator() ( ) const { // make sure requires clause is not broken DLIB_ASSERT(is_set() == true , "\tvoid bound_function_pointer::operator()" << "\n\tYou must call set() before you can use this function" << "\n\tthis: " << this ); bf()->call(); } private: struct dummy{ void nonnull() {}}; typedef void (dummy::*safe_bool)(); public: operator safe_bool () const { return is_set() ? 
&dummy::nonnull : 0; } bool operator!() const { return !is_set(); } // ------------------------------------------- // set function object overloads // ------------------------------------------- template void set ( F& function_object ) { COMPILE_TIME_ASSERT(std::is_function::value == false); COMPILE_TIME_ASSERT(std::is_pointer::value == false); using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.fp = &function_object; temp.safe_clone(bf_memory); } template void set ( F& function_object, A1& arg1 ) { COMPILE_TIME_ASSERT(std::is_function::value == false); COMPILE_TIME_ASSERT(std::is_pointer::value == false); using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.fp = &function_object; temp.safe_clone(bf_memory); } template void set ( F& function_object, A1& arg1, A2& arg2 ) { COMPILE_TIME_ASSERT(std::is_function::value == false); COMPILE_TIME_ASSERT(std::is_pointer::value == false); using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.fp = &function_object; temp.safe_clone(bf_memory); } template void set ( F& function_object, A1& arg1, A2& arg2, A3& arg3 ) { COMPILE_TIME_ASSERT(std::is_function::value == false); COMPILE_TIME_ASSERT(std::is_pointer::value == false); using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.fp = &function_object; temp.safe_clone(bf_memory); } template void set ( F& function_object, A1& arg1, A2& arg2, A3& arg3, A4& arg4 ) { COMPILE_TIME_ASSERT(std::is_function::value == false); COMPILE_TIME_ASSERT(std::is_pointer::value == false); using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; 
bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.arg4 = &arg4; temp.fp = &function_object; temp.safe_clone(bf_memory); } // ------------------------------------------- // set mfp overloads // ------------------------------------------- template void set ( T& object, void (T::*funct)() ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } template void set ( const T& object, void (T::*funct)()const ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } // ------------------------------------------- template void set ( T& object, void (T::*funct)(T1), A1& arg1 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } template void set ( const T& object, void (T::*funct)(T1)const, A1& arg1 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } // ---------------- template void set ( T& object, void (T::*funct)(T1, T2), A1& arg1, A2& arg2 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } template void set ( const T& object, void (T::*funct)(T1, T2)const, A1& arg1, A2& arg2 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } // ---------------- template void 
set ( T& object, void (T::*funct)(T1, T2, T3), A1& arg1, A2& arg2, A3& arg3 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } template void set ( const T& object, void (T::*funct)(T1, T2, T3)const, A1& arg1, A2& arg2, A3& arg3 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } // ---------------- template void set ( T& object, void (T::*funct)(T1, T2, T3, T4), A1& arg1, A2& arg2, A3& arg3, A4& arg4 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.arg4 = &arg4; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } template void set ( const T& object, void (T::*funct)(T1, T2, T3, T4)const, A1& arg1, A2& arg2, A3& arg3, A4& arg4 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.arg4 = &arg4; temp.mfp.set(object,funct); temp.safe_clone(bf_memory); } // ------------------------------------------- // set fp overloads // ------------------------------------------- void set ( void (*funct)() ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.fp = funct; temp.safe_clone(bf_memory); } template void set ( void (*funct)(T1), A1& arg1 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.fp = funct; temp.safe_clone(bf_memory); } template void set ( void 
(*funct)(T1, T2), A1& arg1, A2& arg2 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.fp = funct; temp.safe_clone(bf_memory); } template void set ( void (*funct)(T1, T2, T3), A1& arg1, A2& arg2, A3& arg3 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.fp = funct; temp.safe_clone(bf_memory); } template void set ( void (*funct)(T1, T2, T3, T4), A1& arg1, A2& arg2, A3& arg3, A4& arg4 ) { using namespace bfp1_helpers; destroy_bf_memory(); typedef bound_function_helper_T > bf_helper_type; bf_helper_type temp; temp.arg1 = &arg1; temp.arg2 = &arg2; temp.arg3 = &arg3; temp.arg4 = &arg4; temp.fp = funct; temp.safe_clone(bf_memory); } // ------------------------------------------- private: stack_based_memory_block bf_memory; void destroy_bf_memory ( ) { // Honestly, this probably doesn't even do anything but I'm putting // it here just for good measure. bf()->~bound_function_helper_base_base(); } bfp1_helpers::bound_function_helper_base_base* bf () { return static_cast(bf_memory.get()); } const bfp1_helpers::bound_function_helper_base_base* bf () const { return static_cast(bf_memory.get()); } }; // ---------------------------------------------------------------------------------------- inline void swap ( bound_function_pointer& a, bound_function_pointer& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ ================================================ FILE: dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h ================================================ // Copyright (C) 2008 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
// NOTE: DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ is #undef'd immediately before the
// #ifdef below, so nothing in this file is ever compiled.  It only documents the
// bound_function_pointer interface.
#undef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
#ifdef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_

namespace dlib
{

// ----------------------------------------------------------------------------------------

    class bound_function_pointer
    {
        /*!
            INITIAL VALUE
                is_set() == false

            WHAT THIS OBJECT REPRESENTS
                This object represents a function with all its arguments bound to
                specific objects.  For example:

                    void test(int& var) { var = var+1; }

                    bound_function_pointer funct;

                    int a = 4;
                    funct.set(test,a); // bind the variable a to the first argument of the test() function

                    // at this point a == 4
                    funct();
                    // after funct() is called a == 5
        !*/

    public:

        bound_function_pointer (
        );
        /*!
            ensures
                - #*this is properly initialized
        !*/

        bound_function_pointer(
            const bound_function_pointer& item
        );
        /*!
            ensures
                - *this == item
        !*/

        ~bound_function_pointer (
        );
        /*!
            ensures
                - any resources associated with *this have been released
        !*/

        bound_function_pointer& operator=(
            const bound_function_pointer& item
        );
        /*!
            ensures
                - *this == item
        !*/

        void clear(
        );
        /*!
            ensures
                - #*this has its initial value
        !*/

        bool is_set (
        ) const;
        /*!
            ensures
                - if (this->set() has been called) then
                    - returns true
                - else
                    - returns false
        !*/

        operator some_undefined_pointer_type (
        ) const;
        /*!
            ensures
                - if (is_set()) then
                    - returns a non 0 value
                - else
                    - returns a 0 value
        !*/

        bool operator! (
        ) const;
        /*!
            ensures
                - returns !is_set()
        !*/

        void operator () (
        ) const;
        /*!
            requires
                - is_set() == true
            ensures
                - calls the bound function on the object(s) specified by the last
                  call to this->set()
            throws
                - any exception thrown by the function specified by
                  the previous call to this->set().
                    If any of these exceptions are thrown then the call to this
                    function will have no effect on *this.
        !*/

        void swap (
            bound_function_pointer& item
        );
        /*!
            ensures
                - swaps *this and item
        !*/

        // ----------------------

        template
        void set (
            F& function_object
        );
        /*!
            requires
                - function_object() is a valid expression
            ensures
                - #is_set() == true
                - calls to this->operator() will call function_object()
                  (This seems pointless but it is a useful base case)
        !*/

        template < typename T>
        void set (
            T& object,
            void (T::*funct)()
        );
        /*!
            requires
                - funct == a valid member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)()
        !*/

        template < typename T>
        void set (
            const T& object,
            void (T::*funct)()const
        );
        /*!
            requires
                - funct == a valid const member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)()
        !*/

        void set (
            void (*funct)()
        );
        /*!
            requires
                - funct == a valid function pointer
            ensures
                - #is_set() == true
                - calls to this->operator() will call funct()
        !*/

        // ----------------------

        template
        void set (
            F& function_object,
            A1& arg1
        );
        /*!
            requires
                - function_object(arg1) is a valid expression
            ensures
                - #is_set() == true
                - calls to this->operator() will call function_object(arg1)
        !*/

        template < typename T, typename T1, typename A1 >
        void set (
            T& object,
            void (T::*funct)(T1),
            A1& arg1
        );
        /*!
            requires
                - funct == a valid member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1)
        !*/

        template < typename T, typename T1, typename A1 >
        void set (
            const T& object,
            void (T::*funct)(T1)const,
            A1& arg1
        );
        /*!
            requires
                - funct == a valid const member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1)
        !*/

        template
        void set (
            void (*funct)(T1),
            A1& arg1
        );
        /*!
            requires
                - funct == a valid function pointer
            ensures
                - #is_set() == true
                - calls to this->operator() will call funct(arg1)
        !*/

        // ----------------------
        template
        void set (
            F& function_object,
            A1& arg1,
            A2& arg2
        );
        /*!
            requires
                - function_object(arg1,arg2) is a valid expression
            ensures
                - #is_set() == true
                - calls to this->operator() will call function_object(arg1,arg2)
        !*/

        template < typename T, typename T1, typename A1,
                               typename T2, typename A2>
        void set (
            T& object,
            void (T::*funct)(T1,T2),
            A1& arg1,
            A2& arg2
        );
        /*!
            requires
                - funct == a valid member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1,arg2)
        !*/

        template < typename T, typename T1, typename A1,
                               typename T2, typename A2>
        void set (
            const T& object,
            void (T::*funct)(T1,T2)const,
            A1& arg1,
            A2& arg2
        );
        /*!
            requires
                - funct == a valid const member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1,arg2)
        !*/

        template
        void set (
            void (*funct)(T1,T2),
            A1& arg1,
            A2& arg2
        );
        /*!
            requires
                - funct == a valid function pointer
            ensures
                - #is_set() == true
                - calls to this->operator() will call funct(arg1,arg2)
        !*/

        // ----------------------

        template
        void set (
            F& function_object,
            A1& arg1,
            A2& arg2,
            A3& arg3
        );
        /*!
            requires
                - function_object(arg1,arg2,arg3) is a valid expression
            ensures
                - #is_set() == true
                - calls to this->operator() will call function_object(arg1,arg2,arg3)
        !*/

        template < typename T, typename T1, typename A1,
                               typename T2, typename A2,
                               typename T3, typename A3>
        void set (
            T& object,
            void (T::*funct)(T1,T2,T3),
            A1& arg1,
            A2& arg2,
            A3& arg3
        );
        /*!
            requires
                - funct == a valid member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3)
        !*/

        template < typename T, typename T1, typename A1,
                               typename T2, typename A2,
                               typename T3, typename A3>
        void set (
            const T& object,
            void (T::*funct)(T1,T2,T3)const,
            A1& arg1,
            A2& arg2,
            A3& arg3
        );
        /*!
            requires
                - funct == a valid const member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3)
        !*/

        template
        void set (
            void (*funct)(T1,T2,T3),
            A1& arg1,
            A2& arg2,
            A3& arg3
        );
        /*!
            requires
                - funct == a valid function pointer
            ensures
                - #is_set() == true
                - calls to this->operator() will call funct(arg1,arg2,arg3)
        !*/

        // ----------------------

        template
        void set (
            F& function_object,
            A1& arg1,
            A2& arg2,
            A3& arg3,
            A4& arg4
        );
        /*!
            requires
                - function_object(arg1,arg2,arg3,arg4) is a valid expression
            ensures
                - #is_set() == true
                - calls to this->operator() will call function_object(arg1,arg2,arg3,arg4)
        !*/

        template < typename T, typename T1, typename A1,
                               typename T2, typename A2,
                               typename T3, typename A3,
                               typename T4, typename A4>
        void set (
            T& object,
            void (T::*funct)(T1,T2,T3,T4),
            A1& arg1,
            A2& arg2,
            A3& arg3,
            A4& arg4
        );
        /*!
            requires
                - funct == a valid member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4)
        !*/

        template < typename T, typename T1, typename A1,
                               typename T2, typename A2,
                               typename T3, typename A3,
                               typename T4, typename A4>
        void set (
            const T& object,
            void (T::*funct)(T1,T2,T3,T4)const,
            A1& arg1,
            A2& arg2,
            A3& arg3,
            A4& arg4
        );
        /*!
            requires
                - funct == a valid const member function pointer for class T
            ensures
                - #is_set() == true
                - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4)
        !*/

        template
        void set (
            void (*funct)(T1,T2,T3,T4),
            A1& arg1,
            A2& arg2,
            A3& arg3,
            A4& arg4
        );
        /*!
            requires
                - funct == a valid function pointer
            ensures
                - #is_set() == true
                - calls to this->operator() will call funct(arg1,arg2,arg3,arg4)
        !*/

    };

// ----------------------------------------------------------------------------------------

    inline void swap (
        bound_function_pointer& a,
        bound_function_pointer& b
    ) { a.swap(b); }
    /*!
provides a global swap function !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ ================================================ FILE: dlib/bound_function_pointer.h ================================================ // Copyright (C) 2008 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BOUND_FUNCTION_POINTEr_ #define DLIB_BOUND_FUNCTION_POINTEr_ #include "bound_function_pointer/bound_function_pointer_kernel_1.h" #endif // DLIB_BOUND_FUNCTION_POINTEr_ ================================================ FILE: dlib/bridge/bridge.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BRIDGe_Hh_ #define DLIB_BRIDGe_Hh_ #include #include #include #include "bridge_abstract.h" #include "../pipe.h" #include "../threads.h" #include "../serialize.h" #include "../sockets.h" #include "../sockstreambuf.h" #include "../logger.h" #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- struct connect_to_ip_and_port { connect_to_ip_and_port ( const std::string& ip_, unsigned short port_ ): ip(ip_), port(port_) { // make sure requires clause is not broken DLIB_ASSERT(is_ip_address(ip) && port != 0, "\t connect_to_ip_and_port()" << "\n\t Invalid inputs were given to this function" << "\n\t ip: " << ip << "\n\t port: " << port << "\n\t this: " << this ); } private: friend class bridge; const std::string ip; const unsigned short port; }; inline connect_to_ip_and_port connect_to ( const network_address& addr ) { // make sure requires clause is not broken DLIB_ASSERT(addr.port != 0, "\t connect_to_ip_and_port()" << "\n\t The TCP port to connect to can't be 0." 
<< "\n\t addr.port: " << addr.port ); if (is_ip_address(addr.host_address)) { return connect_to_ip_and_port(addr.host_address, addr.port); } else { std::string ip; if(hostname_to_ip(addr.host_address,ip)) throw socket_error(ERESOLVE,"unable to resolve '" + addr.host_address + "' in connect_to()"); return connect_to_ip_and_port(ip, addr.port); } } struct listen_on_port { listen_on_port( unsigned short port_ ) : port(port_) { // make sure requires clause is not broken DLIB_ASSERT( port != 0, "\t listen_on_port()" << "\n\t Invalid inputs were given to this function" << "\n\t port: " << port << "\n\t this: " << this ); } private: friend class bridge; const unsigned short port; }; template struct bridge_transmit_decoration { bridge_transmit_decoration ( pipe_type& p_ ) : p(p_) {} private: friend class bridge; pipe_type& p; }; template bridge_transmit_decoration transmit ( pipe_type& p) { return bridge_transmit_decoration(p); } template struct bridge_receive_decoration { bridge_receive_decoration ( pipe_type& p_ ) : p(p_) {} private: friend class bridge; pipe_type& p; }; template bridge_receive_decoration receive ( pipe_type& p) { return bridge_receive_decoration(p); } // ---------------------------------------------------------------------------------------- struct bridge_status { bridge_status() : is_connected(false), foreign_port(0){} bool is_connected; unsigned short foreign_port; std::string foreign_ip; }; inline void serialize ( const bridge_status& , std::ostream& ) { throw serialization_error("It is illegal to serialize bridge_status objects."); } inline void deserialize ( bridge_status& , std::istream& ) { throw serialization_error("It is illegal to serialize bridge_status objects."); } // ---------------------------------------------------------------------------------------- namespace impl_brns { class impl_bridge_base { public: virtual ~impl_bridge_base() {} virtual bridge_status get_bridge_status ( ) const = 0; }; template < typename transmit_pipe_type, 
typename receive_pipe_type > class impl_bridge : public impl_bridge_base, private noncopyable, private multithreaded_object { /*! CONVENTION - if (list) then - this object is supposed to be listening on the list object for incoming connections when not connected. - else - this object is supposed to be attempting to connect to ip:port when not connected. - get_bridge_status() == current_bs !*/ public: impl_bridge ( unsigned short listen_port, transmit_pipe_type* transmit_pipe_, receive_pipe_type* receive_pipe_ ) : s(m), receive_thread_active(false), transmit_thread_active(false), port(0), transmit_pipe(transmit_pipe_), receive_pipe(receive_pipe_), dlog("dlib.bridge"), keepalive_code(0), message_code(1) { int status = create_listener(list, listen_port); if (status == PORTINUSE) { std::ostringstream sout; sout << "Error, the port " << listen_port << " is already in use."; throw socket_error(EPORT_IN_USE, sout.str()); } else if (status == OTHER_ERROR) { throw socket_error("Unable to create listening socket for an unknown reason."); } register_thread(*this, &impl_bridge::transmit_thread); register_thread(*this, &impl_bridge::receive_thread); register_thread(*this, &impl_bridge::connect_thread); start(); } impl_bridge ( const std::string ip_, unsigned short port_, transmit_pipe_type* transmit_pipe_, receive_pipe_type* receive_pipe_ ) : s(m), receive_thread_active(false), transmit_thread_active(false), port(port_), ip(ip_), transmit_pipe(transmit_pipe_), receive_pipe(receive_pipe_), dlog("dlib.bridge"), keepalive_code(0), message_code(1) { register_thread(*this, &impl_bridge::transmit_thread); register_thread(*this, &impl_bridge::receive_thread); register_thread(*this, &impl_bridge::connect_thread); start(); } ~impl_bridge() { // tell the threads to terminate stop(); // save current pipe enabled status so we can restore it to however // it was before this destructor ran. 
bool transmit_enabled = true; bool receive_enabled = true; // make any calls blocked on a pipe return immediately. if (transmit_pipe) { transmit_enabled = transmit_pipe->is_dequeue_enabled(); transmit_pipe->disable_dequeue(); } if (receive_pipe) { receive_enabled = receive_pipe->is_enqueue_enabled(); receive_pipe->disable_enqueue(); } { auto_mutex lock(m); s.broadcast(); // Shutdown the connection if we have one. This will cause // all blocked I/O calls to return an error. if (con) con->shutdown(); } // wait for all the threads to terminate. wait(); if (transmit_pipe && transmit_enabled) transmit_pipe->enable_dequeue(); if (receive_pipe && receive_enabled) receive_pipe->enable_enqueue(); } bridge_status get_bridge_status ( ) const { auto_mutex lock(current_bs_mutex); return current_bs; } private: template std::enable_if_t::value> enqueue_bridge_status ( pipe_type* p, const bridge_status& status ) { if (p) { typename pipe_type::type temp(status); p->enqueue(temp); } } template std::enable_if_t::value> enqueue_bridge_status ( pipe_type* , const bridge_status& ) { } void connect_thread ( ) { while (!should_stop()) { auto_mutex lock(m); int status = OTHER_ERROR; if (list) { do { status = list->accept(con, 1000); } while (status == TIMEOUT && !should_stop()); } else { status = create_connection(con, port, ip); } if (should_stop()) break; if (status != 0) { // The last connection attempt failed. So pause for a little bit before making another attempt. 
s.wait_or_timeout(2000); continue; } dlog << LINFO << "Established new connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; bridge_status temp_bs; { auto_mutex lock(current_bs_mutex); current_bs.is_connected = true; current_bs.foreign_port = con->get_foreign_port(); current_bs.foreign_ip = con->get_foreign_ip(); temp_bs = current_bs; } enqueue_bridge_status(receive_pipe, temp_bs); receive_thread_active = true; transmit_thread_active = true; s.broadcast(); // Wait for the transmit and receive threads to end before we continue. // This way we don't invalidate the con pointer while it is in use. while (receive_thread_active || transmit_thread_active) s.wait(); dlog << LINFO << "Closed connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; { auto_mutex lock(current_bs_mutex); current_bs.is_connected = false; current_bs.foreign_port = con->get_foreign_port(); current_bs.foreign_ip = con->get_foreign_ip(); temp_bs = current_bs; } enqueue_bridge_status(receive_pipe, temp_bs); } } void receive_thread ( ) { while (true) { // wait until we have a connection { auto_mutex lock(m); while (!receive_thread_active && !should_stop()) { s.wait(); } if (should_stop()) break; } try { if (receive_pipe) { sockstreambuf buf(con); std::istream in(&buf); typename receive_pipe_type::type item; // This isn't necessary but doing it avoids a warning about // item being uninitialized sometimes. assign_zero_if_built_in_scalar_type(item); while (in.peek() != EOF) { unsigned char code; in.read((char*)&code, sizeof(code)); if (code == message_code) { deserialize(item, in); receive_pipe->enqueue(item); } } } else { // Since we don't have a receive pipe to put messages into we will // just read the bytes from the connection and ignore them. 
char buf[1000]; while (con->read(buf, sizeof(buf)) > 0) ; } } catch (std::bad_alloc& ) { dlog << LERROR << "std::bad_alloc thrown while deserializing message from " << con->get_foreign_ip() << ":" << con->get_foreign_port(); } catch (dlib::serialization_error& e) { dlog << LERROR << "dlib::serialization_error thrown while deserializing message from " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".\nThe exception error message is: \n" << e.what(); } catch (std::exception& e) { dlog << LERROR << "std::exception thrown while deserializing message from " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".\nThe exception error message is: \n" << e.what(); } con->shutdown(); auto_mutex lock(m); receive_thread_active = false; s.broadcast(); } auto_mutex lock(m); receive_thread_active = false; s.broadcast(); } void transmit_thread ( ) { while (true) { // wait until we have a connection { auto_mutex lock(m); while (!transmit_thread_active && !should_stop()) { s.wait(); } if (should_stop()) break; } try { sockstreambuf buf(con); std::ostream out(&buf); typename transmit_pipe_type::type item; // This isn't necessary but doing it avoids a warning about // item being uninitialized sometimes. assign_zero_if_built_in_scalar_type(item); while (out) { bool dequeue_timed_out = false; if (transmit_pipe ) { if (transmit_pipe->dequeue_or_timeout(item,1000)) { out.write((char*)&message_code, sizeof(message_code)); serialize(item, out); if (transmit_pipe->size() == 0) out.flush(); continue; } dequeue_timed_out = (transmit_pipe->is_enabled() && transmit_pipe->is_dequeue_enabled()); } // Pause for about a second. Note that we use a wait_or_timeout() call rather // than sleep() here because we want to wake up immediately if this object is // being destructed rather than hang for a second. 
if (!dequeue_timed_out) { auto_mutex lock(m); if (should_stop()) break; s.wait_or_timeout(1000); } // Just send the keepalive byte periodically so we can // tell if the connection is alive. out.write((char*)&keepalive_code, sizeof(keepalive_code)); out.flush(); } } catch (std::bad_alloc& ) { dlog << LERROR << "std::bad_alloc thrown while serializing message to " << con->get_foreign_ip() << ":" << con->get_foreign_port(); } catch (dlib::serialization_error& e) { dlog << LERROR << "dlib::serialization_error thrown while serializing message to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".\nThe exception error message is: \n" << e.what(); } catch (std::exception& e) { dlog << LERROR << "std::exception thrown while serializing message to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".\nThe exception error message is: \n" << e.what(); } con->shutdown(); auto_mutex lock(m); transmit_thread_active = false; s.broadcast(); } auto_mutex lock(m); transmit_thread_active = false; s.broadcast(); } mutex m; signaler s; bool receive_thread_active; bool transmit_thread_active; std::unique_ptr con; std::unique_ptr list; const unsigned short port; const std::string ip; transmit_pipe_type* const transmit_pipe; receive_pipe_type* const receive_pipe; logger dlog; const unsigned char keepalive_code; const unsigned char message_code; mutex current_bs_mutex; bridge_status current_bs; }; } // ---------------------------------------------------------------------------------------- class bridge : noncopyable { public: bridge () {} template < typename T, typename U, typename V > bridge ( T network_parameters, U pipe1, V pipe2 ) { reconfigure(network_parameters,pipe1,pipe2); } template < typename T, typename U> bridge ( T network_parameters, U pipe ) { reconfigure(network_parameters,pipe); } void clear ( ) { pimpl.reset(); } template < typename T, typename R > void reconfigure ( listen_on_port network_parameters, bridge_transmit_decoration transmit_pipe, 
bridge_receive_decoration receive_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } template < typename T, typename R > void reconfigure ( listen_on_port network_parameters, bridge_receive_decoration receive_pipe, bridge_transmit_decoration transmit_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } template < typename T > void reconfigure ( listen_on_port network_parameters, bridge_transmit_decoration transmit_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, 0)); } template < typename R > void reconfigure ( listen_on_port network_parameters, bridge_receive_decoration receive_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, 0, &receive_pipe.p)); } template < typename T, typename R > void reconfigure ( connect_to_ip_and_port network_parameters, bridge_transmit_decoration transmit_pipe, bridge_receive_decoration receive_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } template < typename T, typename R > void reconfigure ( connect_to_ip_and_port network_parameters, bridge_receive_decoration receive_pipe, bridge_transmit_decoration transmit_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } template < typename R > void reconfigure ( connect_to_ip_and_port network_parameters, bridge_receive_decoration receive_pipe ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); } template < typename T > void reconfigure ( connect_to_ip_and_port network_parameters, bridge_transmit_decoration transmit_pipe ) { pimpl.reset(); pimpl.reset(new 
impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); } bridge_status get_bridge_status ( ) const { if (pimpl) return pimpl->get_bridge_status(); else return bridge_status(); } private: std::unique_ptr pimpl; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_BRIDGe_Hh_ ================================================ FILE: dlib/bridge/bridge_abstract.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BRIDGe_ABSTRACT_ #ifdef DLIB_BRIDGe_ABSTRACT_ #include #include "../pipe/pipe_kernel_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- struct connect_to_ip_and_port { connect_to_ip_and_port ( const std::string& ip, unsigned short port ); /*! requires - is_ip_address(ip) == true - port != 0 ensures - this object will represent a request to make a TCP connection to the given IP address and port number. !*/ }; connect_to_ip_and_port connect_to ( const network_address& addr ); /*! requires - addr.port != 0 ensures - converts the given network_address object into a connect_to_ip_and_port object. !*/ struct listen_on_port { listen_on_port( unsigned short port ); /*! requires - port != 0 ensures - this object will represent a request to listen on the given port number for incoming TCP connections. !*/ }; template < typename pipe_type > bridge_transmit_decoration transmit ( pipe_type& p ); /*! requires - pipe_type is some kind of dlib::pipe object - the objects in the pipe must be serializable ensures - Adds a type decoration to the given pipe, marking it as a transmit pipe, and then returns it. !*/ template < typename pipe_type > bridge_receive_decoration receive ( pipe_type& p ); /*! 
requires - pipe_type is some kind of dlib::pipe object - the objects in the pipe must be serializable ensures - Adds a type decoration to the given pipe, marking it as a receive pipe, and then returns it. !*/ // ---------------------------------------------------------------------------------------- struct bridge_status { /*! WHAT THIS OBJECT REPRESENTS This simple struct represents the state of a bridge object. A bridge is either connected or not. If it is connected then it is connected to a foreign host with an IP address and port number as indicated by this object. !*/ bridge_status( ); /*! ensures - #is_connected == false - #foreign_port == 0 - #foreign_ip == "" !*/ bool is_connected; unsigned short foreign_port; std::string foreign_ip; }; // ---------------------------------------------------------------------------------------- class bridge : noncopyable { /*! WHAT THIS OBJECT REPRESENTS This object is a tool for bridging a dlib::pipe object between two network connected applications. Note also that this object contains a dlib::logger object which will log various events taking place inside a bridge. If you want to see these log messages then enable the logger named "dlib.bridge". BRIDGE PROTOCOL DETAILS The bridge object creates a single TCP connection between two applications. Whenever it sends an object from a pipe over a TCP connection it sends a byte with the value 1 followed immediately by the serialized copy of the object from the pipe. The serialization is performed by calling the global serialize() function. Additionally, a bridge object will periodically send bytes with a value of 0 to ensure the TCP connection remains alive. These are just read and ignored. !*/ public: bridge ( ); /*! ensures - this object is properly initialized - #get_bridge_status().is_connected == false !*/ template bridge ( T network_parameters, U pipe1, V pipe2 ); /*! 
requires - T is of type connect_to_ip_and_port or listen_on_port - U and V are of type bridge_transmit_decoration or bridge_receive_decoration, however, U and V must be of different types (i.e. one is a receive type and another a transmit type). ensures - this object is properly initialized - performs: reconfigure(network_parameters, pipe1, pipe2) (i.e. using this constructor is identical to using the default constructor and then calling reconfigure()) !*/ template bridge ( T network_parameters, U pipe ); /*! requires - T is of type connect_to_ip_and_port or listen_on_port - U is of type bridge_transmit_decoration or bridge_receive_decoration. ensures - this object is properly initialized - performs: reconfigure(network_parameters, pipe) (i.e. using this constructor is identical to using the default constructor and then calling reconfigure()) !*/ ~bridge ( ); /*! ensures - blocks until all resources associated with this object have been destroyed. !*/ void clear ( ); /*! ensures - returns this object to its default constructed state. That is, it will be inactive, neither maintaining a connection nor attempting to acquire one. - Any active connections or listening sockets will be closed. !*/ bridge_status get_bridge_status ( ) const; /*! ensures - returns the current status of this bridge object. In particular, returns an object BS such that: - BS.is_connected == true if and only if the bridge has an active TCP connection to another computer. - if (BS.is_connected) then - BS.foreign_ip == the IP address of the remote host we are connected to. - BS.foreign_port == the port number on the remote host we are connected to. - else if (the bridge has previously been connected to a remote host but hasn't been reconfigured or cleared since) then - BS.foreign_ip == the IP address of the remote host we were connected to. - BS.foreign_port == the port number on the remote host we were connected to. 
- else - BS.foreign_ip == "" - BS.foreign_port == 0 !*/ template < typename T, typename R > void reconfigure ( listen_on_port network_parameters, bridge_transmit_decoration transmit_pipe, bridge_receive_decoration receive_pipe ); /*! ensures - This object will begin listening on the port specified by network_parameters for incoming TCP connections. Any previous bridge state is cleared out. - Onces a connection is established we will: - Stop accepting new connections. - Begin dequeuing objects from the transmit pipe and serializing them over the TCP connection. - Begin deserializing objects from the TCP connection and enqueueing them onto the receive pipe. - if (the current TCP connection is lost) then - This object goes back to listening for a new connection. - if (the receive pipe can contain bridge_status objects) then - Whenever the bridge's status changes the updated bridge_status will be enqueued onto the receive pipe unless the change was a TCP disconnect resulting from a user calling reconfigure(), clear(), or destructing this bridge. The status contents are defined by get_bridge_status(). throws - socket_error This exception is thrown if we are unable to open the listening socket. !*/ template < typename T, typename R > void reconfigure ( listen_on_port network_parameters, bridge_receive_decoration receive_pipe, bridge_transmit_decoration transmit_pipe ); /*! ensures - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) !*/ template < typename T > void reconfigure ( listen_on_port network_parameters, bridge_transmit_decoration transmit_pipe ); /*! ensures - This function is identical to the above two reconfigure() functions except that there is no receive pipe. !*/ template < typename R > void reconfigure ( listen_on_port network_parameters, bridge_receive_decoration receive_pipe ); /*! ensures - This function is identical to the above three reconfigure() functions except that there is no transmit pipe. 
!*/ template void reconfigure ( connect_to_ip_and_port network_parameters, bridge_transmit_decoration transmit_pipe, bridge_receive_decoration receive_pipe ); /*! ensures - This object will begin making TCP connection attempts to the IP address and port specified by network_parameters. Any previous bridge state is cleared out. - Onces a connection is established we will: - Stop attempting new connections. - Begin dequeuing objects from the transmit pipe and serializing them over the TCP connection. - Begin deserializing objects from the TCP connection and enqueueing them onto the receive pipe. - if (the current TCP connection is lost) then - This object goes back to attempting to make a TCP connection with the IP address and port specified by network_parameters. - if (the receive pipe can contain bridge_status objects) then - Whenever the bridge's status changes the updated bridge_status will be enqueued onto the receive pipe unless the change was a TCP disconnect resulting from a user calling reconfigure(), clear(), or destructing this bridge. The status contents are defined by get_bridge_status(). !*/ template void reconfigure ( connect_to_ip_and_port network_parameters, bridge_receive_decoration receive_pipe, bridge_transmit_decoration transmit_pipe ); /*! ensures - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) !*/ template void reconfigure ( connect_to_ip_and_port network_parameters, bridge_transmit_decoration transmit_pipe ); /*! ensures - This function is identical to the above two reconfigure() functions except that there is no receive pipe. !*/ template void reconfigure ( connect_to_ip_and_port network_parameters, bridge_receive_decoration receive_pipe ); /*! ensures - This function is identical to the above three reconfigure() functions except that there is no transmit pipe. 
!*/ }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_BRIDGe_ABSTRACT_ ================================================ FILE: dlib/bridge.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifdef DLIB_ALL_SOURCE_END #include "dlib_basic_cpp_build_tutorial.txt" #endif #ifndef DLIB_BRIdGE_ #define DLIB_BRIdGE_ #include "bridge/bridge.h" #endif // DLIB_BRIdGE_ ================================================ FILE: dlib/bsp/bsp.cpp ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BSP_CPph_ #define DLIB_BSP_CPph_ #include "bsp.h" #include #include // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- namespace dlib { namespace impl1 { void connect_all ( map_id_to_con& cons, const std::vector& hosts, unsigned long node_id ) { cons.clear(); for (unsigned long i = 0; i < hosts.size(); ++i) { std::unique_ptr con(new bsp_con(hosts[i])); dlib::serialize(node_id, con->stream); // tell the other end our node_id unsigned long id = i+1; cons.add(id, con); } } void connect_all_hostinfo ( map_id_to_con& cons, const std::vector& hosts, unsigned long node_id, std::string& error_string ) { cons.clear(); for (unsigned long i = 0; i < hosts.size(); ++i) { try { std::unique_ptr con(new bsp_con(hosts[i].addr)); dlib::serialize(node_id, con->stream); // tell the other end our node_id con->stream.flush(); unsigned long id = hosts[i].node_id; cons.add(id, con); } catch (std::exception&) { std::ostringstream sout; sout << "Could not connect to " << hosts[i].addr; error_string = sout.str(); break; } } } void send_out_connection_orders ( 
map_id_to_con& cons, const std::vector& hosts ) { // tell everyone their node ids cons.reset(); while (cons.move_next()) { dlib::serialize(cons.element().key(), cons.element().value()->stream); } // now tell them who to connect to std::vector targets; for (unsigned long i = 0; i < hosts.size(); ++i) { hostinfo info(hosts[i], i+1); dlib::serialize(targets, cons[info.node_id]->stream); targets.push_back(info); // let the other host know how many incoming connections to expect const unsigned long num = hosts.size()-targets.size(); dlib::serialize(num, cons[info.node_id]->stream); cons[info.node_id]->stream.flush(); } } // ------------------------------------------------------------------------------------ } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- namespace impl2 { // These control bytes are sent before each message between nodes. Note that many // of these are only sent between the control node (node 0) and the other nodes. // This is because the controller node is responsible for handling the // synchronization that needs to happen when all nodes block on calls to // receive_data() // at the same time. // denotes a normal content message. const static char MESSAGE_HEADER = 0; // sent to the controller node when someone receives a message via receive_data(). const static char GOT_MESSAGE = 1; // sent to the controller node when someone sends a message via send(). const static char SENT_MESSAGE = 2; // sent to the controller node when someone enters a call to receive_data() const static char IN_WAITING_STATE = 3; // broadcast when a node terminates itself. 
        const static char NODE_TERMINATE = 5;

        // broadcast by the controller node when it determines that all nodes are blocked
        // on calls to receive_data() and there aren't any messages in flight.  This is also
        // what makes us go to the next epoch.
        const static char SEE_ALL_IN_WAITING_STATE = 6;

        // This isn't ever transmitted between nodes.  It is used internally to indicate
        // that an error occurred.
        const static char READ_ERROR = 7;

    // ------------------------------------------------------------------------------------

        // Body of the per-connection reader thread spawned by the bsp_context constructor.
        // Repeatedly reads a control byte from con->stream and pushes it into msg_buffer,
        // tagged with sender_id (the id of the node at the other end of con).
        //   - For MESSAGE_HEADER the epoch and payload that follow it on the wire are read
        //     too and attached to the pushed msg_data.  Other control bytes keep
        //     msg_data's default epoch (0).  That is why, in msg_buffer, they sort ahead of
        //     content messages.
        //   - Once a NODE_TERMINATE has been pushed the thread stops reading and returns.
        //   - If reading throws, both catch clauses push a READ_ERROR msg_data instead.
        //     Its payload is a human readable description: the two endpoint addresses,
        //     sender_id, node_id and, for std::exception, e.what().  The thread then
        //     returns.
        // node_id is the id of the node running this thread; it is only used in that text.
        // NOTE(review): the calls that build the error text (get_foreign_ip() etc.) run
        // outside any try block -- confirm they cannot throw.
        void read_thread (
            impl1::bsp_con* con,
            unsigned long node_id,
            unsigned long sender_id,
            impl1::thread_safe_message_queue& msg_buffer
        )
        {
            try
            {
                while(true)
                {
                    impl1::msg_data msg;
                    deserialize(msg.msg_type, con->stream);
                    msg.sender_id = sender_id;

                    // Only content messages carry an epoch and a payload on the wire.
                    if (msg.msg_type == MESSAGE_HEADER)
                    {
                        msg.data.reset(new std::vector);
                        deserialize(msg.epoch, con->stream);
                        deserialize(*msg.data, con->stream);
                    }

                    msg_buffer.push_and_consume(msg);

                    // Stop reading once the remote node has announced it is terminating.
                    if (msg.msg_type == NODE_TERMINATE)
                        break;
                }
            }
            catch (std::exception& e)
            {
                impl1::msg_data msg;
                msg.data.reset(new std::vector);
                vectorstream sout(*msg.data);
                sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
                sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
                sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
                sout << " Receiving processing node id: " << node_id << std::endl;
                sout << " Error message in the exception: " << e.what() << std::endl;
                msg.sender_id = sender_id;
                msg.msg_type = READ_ERROR;

                msg_buffer.push_and_consume(msg);
            }
            catch (...)
            {
                // Same report as above, minus the exception text which isn't available.
                impl1::msg_data msg;
                msg.data.reset(new std::vector);
                vectorstream sout(*msg.data);
                sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
                sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
                sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
                sout << " Receiving processing node id: " << node_id << std::endl;
                msg.sender_id = sender_id;
                msg.msg_type = READ_ERROR;

                msg_buffer.push_and_consume(msg);
            }
        }

    // ------------------------------------------------------------------------------------

    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                          IMPLEMENTATION OF bsp_context OBJECT MEMBERS
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Shutdown handshake, called by bsp_connect()/bsp_listen_dynamic_port() after the
    // user's function returns.
    //   - A non-controller node first sends NODE_TERMINATE to every peer.
    //   - Every node then drains msg_buffer until it has seen NODE_TERMINATE from all
    //     peers.
    //   - While draining, node 0 keeps acting as the controller.  It tracks
    //     waiting/outstanding counts.  Whenever every peer is waiting or terminated and
    //     nothing is outstanding, it broadcasts SEE_ALL_IN_WAITING_STATE and advances the
    //     epoch, so that peers still blocked in a receive call can return.
    //   - Node 0 sends its own NODE_TERMINATE only after all peers have terminated.  It
    //     then reports any messages that were sent but never received.
    // Throws dlib::socket_error in these cases:
    //   - msg_buffer is disabled;
    //   - a reader thread reported READ_ERROR;
    //   - a content message arrives during shutdown;
    //   - (node 0) messages are still outstanding at the end.
    void bsp_context::
    close_all_connections_gracefully(
    )
    {
        if (node_id() != 0)
        {
            _cons.reset();
            while (_cons.move_next())
            {
                // tell the other end that we are intentionally dropping the connection
                serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
                _cons.element().value()->stream.flush();
            }
        }

        impl1::msg_data msg;
        // now wait for all the other nodes to terminate
        while (num_terminated_nodes < _cons.size() )
        {
            // Same barrier test as in receive_data(): release peers that are blocked in
            // a receive call once nothing else can happen.
            if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
            {
                num_waiting_nodes = 0;
                broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
                ++current_epoch;
            }

            if (!msg_buffer.pop(msg))
                throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");

            if (msg.msg_type == impl2::NODE_TERMINATE)
            {
                ++num_terminated_nodes;
                _cons[msg.sender_id]->terminated = true;
            }
            else if (msg.msg_type == impl2::READ_ERROR)
            {
                throw dlib::socket_error(msg.data_to_string());
            }
            else if (msg.msg_type == impl2::MESSAGE_HEADER)
            {
                throw dlib::socket_error("A BSP node received a message after it has terminated.");
            }
            else if (msg.msg_type == impl2::GOT_MESSAGE)
            {
                --num_waiting_nodes;
                --outstanding_messages;
            }
            else if (msg.msg_type == impl2::SENT_MESSAGE)
            {
                ++outstanding_messages;
            }
            else if (msg.msg_type == impl2::IN_WAITING_STATE)
            {
                ++num_waiting_nodes;
            }
        }

        if (node_id() == 0)
        {
            _cons.reset();
            while (_cons.move_next())
            {
                // tell the other end that we are intentionally dropping the connection
                serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
                _cons.element().value()->stream.flush();
            }

            if (outstanding_messages != 0)
            {
                std::ostringstream sout;
                sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
                sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
                sout << "have a corresponding call to receive().";
                throw dlib::socket_error(sout.str());
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // Tears everything down:
    //   1. shuts down every connection;
    //   2. disables msg_buffer, so any pending pop() returns false;
    //   3. joins the reader threads.
    // NOTE(review): assumes shutdown() makes each read_thread()'s blocking read fail, so
    // that threads.clear() can actually join them -- confirm in the sockets layer.
    bsp_context::
    ~bsp_context()
    {
        _cons.reset();
        while (_cons.move_next())
        {
            _cons.element().value()->con->shutdown();
        }

        msg_buffer.disable();

        // this will wait for all the threads to terminate
        threads.clear();
    }

// ----------------------------------------------------------------------------------------

    // Starts one read_thread() per connection in cons_; all of them push into the shared
    // msg_buffer.  cons_ is kept by reference (_cons), so it must outlive this object.
    // The callers in bsp.h pass a local map declared before this object.  current_epoch
    // starts at 1, above the default epoch (0) carried by control messages.
    bsp_context::
    bsp_context(
        unsigned long node_id_,
        impl1::map_id_to_con& cons_
    ) :
        outstanding_messages(0),
        num_waiting_nodes(0),
        num_terminated_nodes(0),
        current_epoch(1),
        _cons(cons_),
        _node_id(node_id_)
    {
        // spawn a bunch of read threads, one for each connection
        _cons.reset();
        while (_cons.move_next())
        {
            std::unique_ptr ptr(new thread_function(&impl2::read_thread,
                                                    _cons.element().value().get(),
                                                    _node_id,
                                                    _cons.element().key(),
                                                    ref(msg_buffer)));
            threads.push_back(ptr);
        }
    }
// ----------------------------------------------------------------------------------------

    // Shared implementation of all the receive methods.  It first tells the controller
    // that this node is waiting (IN_WAITING_STATE).  It then loops until one of these
    // happens:
    //   - A content message with epoch <= current_epoch is popped from msg_buffer.  It
    //     is returned through item and sending_node_id, the controller is told
    //     (GOT_MESSAGE), and true is returned.
    //   - Every peer has terminated and msg_buffer is empty: false is returned.
    //   - (node 0 only) Every peer is waiting or terminated and no messages are
    //     outstanding.  SEE_ALL_IN_WAITING_STATE is broadcast, the epoch advances, and
    //     false is returned.
    //   - SEE_ALL_IN_WAITING_STATE arrives from node 0.  The epoch advances and false
    //     is returned.
    // Every other control byte only updates the bookkeeping counters.
    // Throws dlib::socket_error in these cases:
    //   - msg_buffer has been disabled;
    //   - a reader thread reported READ_ERROR;
    //   - an unknown control byte arrives.
    bool bsp_context::
    receive_data (
        std::shared_ptr >& item,
        unsigned long& sending_node_id
    )
    {
        notify_control_node(impl2::IN_WAITING_STATE);

        while (true)
        {
            // If there aren't any nodes left to give us messages then return right now.
            // We need to check the msg_buffer size to make sure there aren't any
            // unprocessed messages there.  Recall that this can happen because status
            // messages always jump to the front of the message buffer.  So we might have
            // learned about the node terminations before processing their messages for us.
            if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
            {
                return false;
            }

            // if all running nodes are currently blocking forever on receive_data()
            if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
            {
                num_waiting_nodes = 0;
                broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);

                // Note that the reason we have this epoch counter is so we can tell if a
                // sent message is from before or after one of these "all nodes waiting"
                // synchronization events.  If we didn't have the epoch count we would have
                // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
                // message before others and then sends out a message to another node
                // before that node got the SEE_ALL_IN_WAITING_STATE message.  Then that
                // node would think the normal message came before SEE_ALL_IN_WAITING_STATE
                // which would be bad.
                ++current_epoch;
                return false;
            }

            impl1::msg_data data;
            // Only messages whose epoch is <= current_epoch may be taken.  Content sent by
            // a node that has already passed a newer barrier stays queued until we pass
            // it too.
            if (!msg_buffer.pop(data, current_epoch))
                throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");

            switch(data.msg_type)
            {
                case impl2::MESSAGE_HEADER: {
                    item = data.data;
                    sending_node_id = data.sender_id;
                    notify_control_node(impl2::GOT_MESSAGE);
                    return true;
                } break;

                case impl2::IN_WAITING_STATE: {
                    ++num_waiting_nodes;
                } break;

                case impl2::GOT_MESSAGE: {
                    --outstanding_messages;
                    --num_waiting_nodes;
                } break;

                case impl2::SENT_MESSAGE: {
                    ++outstanding_messages;
                } break;

                case impl2::NODE_TERMINATE: {
                    ++num_terminated_nodes;
                    _cons[data.sender_id]->terminated = true;
                } break;

                case impl2::SEE_ALL_IN_WAITING_STATE: {
                    ++current_epoch;
                    return false;
                } break;

                case impl2::READ_ERROR: {
                    throw dlib::socket_error(data.data_to_string());
                } break;

                default: {
                    throw dlib::socket_error("Unknown message received by dlib::bsp_context");
                } break;
            } // end switch()
        } // end while (true)
    }

// ----------------------------------------------------------------------------------------

    // Reports a bookkeeping event (SENT_MESSAGE, GOT_MESSAGE or IN_WAITING_STATE) to the
    // controller.
    //   - On node 0 the controller is this object, so the counters are updated directly.
    //     IN_WAITING_STATE needs no work, because node 0's own waiting state is not part
    //     of num_waiting_nodes, which only counts peers.
    //   - Any other val on node 0 trips DLIB_CASSERT.
    //   - Every other node writes val to its connection with node 0 and flushes.
    void bsp_context::
    notify_control_node (
        char val
    )
    {
        if (node_id() == 0)
        {
            using namespace impl2;
            switch(val)
            {
                case SENT_MESSAGE: {
                    ++outstanding_messages;
                } break;

                case GOT_MESSAGE: {
                    --outstanding_messages;
                } break;

                case IN_WAITING_STATE: {
                    // nothing to do in this case
                } break;

                default:
                    DLIB_CASSERT(false,"This should never happen");
            }
        }
        else
        {
            serialize(val, _cons[0]->stream);
            _cons[0]->stream.flush();
        }
    }

// ----------------------------------------------------------------------------------------

    // Sends the single control byte val to every other node that has not terminated,
    // flushing each stream.  Relies on _cons having an entry for every other node id in
    // [0, number_of_nodes()).
    void bsp_context::
    broadcast_byte (
        char val
    )
    {
        for (unsigned long i = 0; i < number_of_nodes(); ++i)
        {
            // don't send to yourself or to terminated nodes
            if (i == node_id() || _cons[i]->terminated)
                continue;

            serialize(val, _cons[i]->stream);
            _cons[i]->stream.flush();
        }
    }

// ----------------------------------------------------------------------------------------

    // Writes one content message to target_node_id and flushes.  The wire format is
    // MESSAGE_HEADER, then our current_epoch, then the serialized payload, which is the
    // order read_thread() reads them in.  The controller is then told a message is in
    // flight (SENT_MESSAGE).  Throws socket_error if the target has already terminated.
    void bsp_context::
    send_data(
        const std::vector& item,
        unsigned long target_node_id
    )
    {
        using namespace impl2;
        if (_cons[target_node_id]->terminated)
            throw socket_error("Attempt to send a message to a node that has terminated.");

        serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
        serialize(current_epoch, _cons[target_node_id]->stream);
        serialize(item, _cons[target_node_id]->stream);
        _cons[target_node_id]->stream.flush();

        notify_control_node(SENT_MESSAGE);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_BSP_CPph_

================================================
FILE: dlib/bsp/bsp.h
================================================
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BsP_Hh_
#define DLIB_BsP_Hh_

#include "bsp_abstract.h"
// NOTE(review): the header names of the next three #include directives are missing in
// this copy of the file.
#include
#include
#include

#include "../sockets.h"
#include "../array.h"
#include "../sockstreambuf.h"
#include "../string.h"
#include "../serialize.h"
#include "../map.h"
#include "../ref.h"
#include "../vectorstream.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl1
    {
        // Port-notification callback that ignores the port; bsp_listen() passes it to
        // bsp_listen_dynamic_port().
        inline void null_notify(
            unsigned short
        ) {}

        // One connection to another BSP node: the socket (con), a sockstreambuf over it,
        // and an iostream over that buffer.  terminated is set once that node has sent
        // NODE_TERMINATE.
        // Member order matters: con is declared before buf, so in the first constructor
        // the connection exists before buf is built from it.
        struct bsp_con
        {
            bsp_con(
                const network_address& dest
            ) :
                con(connect(dest)),
                buf(con),
                stream(&buf),
                terminated(false)
            {
                // NOTE(review): presumably disabled so that the small, individually
                // flushed control bytes are not delayed by Nagle's algorithm.
                con->disable_nagle();
            }

            bsp_con(
               std::unique_ptr& conptr
            ) :
                buf(conptr),
                stream(&buf),
                terminated(false)
            {
                // make sure we own the connection
                conptr.swap(con);

                con->disable_nagle();
            }

            std::unique_ptr con;
            sockstreambuf buf;
            std::iostream stream;
            bool terminated;
        };

        // Maps a remote node id to the connection with that node.
        typedef dlib::map >::kernel_1a_c map_id_to_con;

        void connect_all (
            map_id_to_con& cons,
            const std::vector& hosts,
            unsigned long node_id
        );
        /*!
ensures - creates connections to all the given hosts and stores them into cons !*/ void send_out_connection_orders ( map_id_to_con& cons, const std::vector& hosts ); // ------------------------------------------------------------------------------------ struct hostinfo { hostinfo() {} hostinfo ( const network_address& addr_, unsigned long node_id_ ) : addr(addr_), node_id(node_id_) { } network_address addr; unsigned long node_id; }; inline void serialize ( const hostinfo& item, std::ostream& out ) { dlib::serialize(item.addr, out); dlib::serialize(item.node_id, out); } inline void deserialize ( hostinfo& item, std::istream& in ) { dlib::deserialize(item.addr, in); dlib::deserialize(item.node_id, in); } // ------------------------------------------------------------------------------------ void connect_all_hostinfo ( map_id_to_con& cons, const std::vector& hosts, unsigned long node_id, std::string& error_string ); // ------------------------------------------------------------------------------------ template < typename port_notify_function_type > void listen_and_connect_all( unsigned long& node_id, map_id_to_con& cons, unsigned short port, port_notify_function_type port_notify_function ) { cons.clear(); std::unique_ptr list; const int status = create_listener(list, port); if (status == PORTINUSE) { throw socket_error("Unable to create listening port " + cast_to_string(port) + ". 
The port is already in use"); } else if (status != 0) { throw socket_error("Unable to create listening port " + cast_to_string(port) ); } port_notify_function(list->get_listening_port()); std::unique_ptr con; if (list->accept(con)) { throw socket_error("Error occurred while accepting new connection"); } std::unique_ptr temp(new bsp_con(con)); unsigned long remote_node_id; dlib::deserialize(remote_node_id, temp->stream); dlib::deserialize(node_id, temp->stream); std::vector targets; dlib::deserialize(targets, temp->stream); unsigned long num_incoming_connections; dlib::deserialize(num_incoming_connections, temp->stream); cons.add(remote_node_id,temp); // make a thread that will connect to all the targets map_id_to_con cons2; std::string error_string; thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string)); if (error_string.size() != 0) throw socket_error(error_string); // accept any incoming connections for (unsigned long i = 0; i < num_incoming_connections; ++i) { // If it takes more than 10 seconds for the other nodes to connect to us // then something has gone horribly wrong and it almost certainly will // never connect at all. So just give up if that happens. 
const unsigned long timeout_milliseconds = 10000; if (list->accept(con, timeout_milliseconds)) { throw socket_error("Error occurred while accepting new connection"); } temp.reset(new bsp_con(con)); dlib::deserialize(remote_node_id, temp->stream); cons.add(remote_node_id,temp); } // put all the connections created by the thread into cons thread.wait(); while (cons2.size() > 0) { unsigned long id; std::unique_ptr temp; cons2.remove_any(id,temp); cons.add(id,temp); } } // ------------------------------------------------------------------------------------ struct msg_data { std::shared_ptr > data; unsigned long sender_id; char msg_type; dlib::uint64 epoch; msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {} std::string data_to_string() const { if (data && data->size() != 0) return std::string(&(*data)[0], data->size()); else return ""; } }; // ------------------------------------------------------------------------------------ class thread_safe_message_queue : noncopyable { /*! WHAT THIS OBJECT REPRESENTS This is a simple message queue for msg_data objects. Note that it has the special property that, while messages will generally leave the queue in the order they are inserted, any message with a smaller epoch value will always be popped out first. But for all messages with equal epoch values the queue functions as a normal FIFO queue. !*/ private: struct msg_wrap { msg_wrap( const msg_data& data_, const dlib::uint64& sequence_number_ ) : data(data_), sequence_number(sequence_number_) {} msg_wrap() : sequence_number(0){} msg_data data; dlib::uint64 sequence_number; // Make it so that when msg_wrap objects are in a std::priority_queue, // messages with a smaller epoch number always come first. Then, within an // epoch, messages are ordered by their sequence number (so smaller first // there as well). 
bool operator<(const msg_wrap& item) const { if (data.epoch < item.data.epoch) { return false; } else if (data.epoch > item.data.epoch) { return true; } else { if (sequence_number < item.sequence_number) return false; else return true; } } }; public: thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {} ~thread_safe_message_queue() { disable(); } void disable() { auto_mutex lock(class_mutex); disabled = true; sig.broadcast(); } unsigned long size() const { auto_mutex lock(class_mutex); return data.size(); } void push_and_consume( msg_data& item) { auto_mutex lock(class_mutex); data.push(msg_wrap(item, next_seq_num++)); // do this here so that we don't have to worry about different threads touching the shared_ptr. item.data.reset(); sig.signal(); } bool pop ( msg_data& item ) /*! ensures - if (this function returns true) then - #item == the next thing from the queue - else - this object is disabled !*/ { auto_mutex lock(class_mutex); while (data.size() == 0 && !disabled) sig.wait(); if (disabled) return false; item = data.top().data; data.pop(); return true; } bool pop ( msg_data& item, const dlib::uint64& max_epoch ) /*! 
ensures - if (this function returns true) then - #item == the next thing from the queue that has an epoch <= max_epoch - else - this object is disabled !*/ { auto_mutex lock(class_mutex); while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled) sig.wait(); if (disabled) return false; item = data.top().data; data.pop(); return true; } private: std::priority_queue data; dlib::mutex class_mutex; dlib::signaler sig; bool disabled; dlib::uint64 next_seq_num; }; } // ---------------------------------------------------------------------------------------- class bsp_context : noncopyable { public: template void send( const T& item, unsigned long target_node_id ) { // make sure requires clause is not broken DLIB_CASSERT(target_node_id < number_of_nodes() && target_node_id != node_id(), "\t void bsp_context::send()" << "\n\t Invalid arguments were given to this function." << "\n\t target_node_id: " << target_node_id << "\n\t node_id(): " << node_id() << "\n\t number_of_nodes(): " << number_of_nodes() << "\n\t this: " << this ); std::vector buf; vectorstream sout(buf); serialize(item, sout); send_data(buf, target_node_id); } template void broadcast ( const T& item ) { std::vector buf; vectorstream sout(buf); serialize(item, sout); for (unsigned long i = 0; i < number_of_nodes(); ++i) { // Don't send to yourself. 
if (i == node_id()) continue; send_data(buf, i); } } unsigned long node_id ( ) const { return _node_id; } unsigned long number_of_nodes ( ) const { return _cons.size()+1; } void receive ( ) { unsigned long id; std::shared_ptr > temp; if (receive_data(temp,id)) throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message."); } template void receive ( T& item ) { if(!try_receive(item)) throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); } template bool try_receive ( T& item ) { unsigned long sending_node_id; return try_receive(item, sending_node_id); } template void receive ( T& item, unsigned long& sending_node_id ) { if(!try_receive(item, sending_node_id)) throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); } template bool try_receive ( T& item, unsigned long& sending_node_id ) { std::shared_ptr > temp; if (receive_data(temp, sending_node_id)) { vectorstream sin(*temp); deserialize(item, sin); if (sin.peek() != EOF) throw serialization_error("deserialize() did not consume all bytes produced by serialize(). " "This probably means you are calling a receive method with a different type " "of object than the one which was sent."); return true; } else { return false; } } ~bsp_context(); private: bsp_context(); bsp_context( unsigned long node_id_, impl1::map_id_to_con& cons_ ); void close_all_connections_gracefully(); /*! ensures - closes all the connections to other nodes and lets them know that we are terminating normally rather than as the result of some kind of error. !*/ bool receive_data ( std::shared_ptr >& item, unsigned long& sending_node_id ); void notify_control_node ( char val ); void broadcast_byte ( char val ); void send_data( const std::vector& item, unsigned long target_node_id ); /*! requires - target_node_id < number_of_nodes() - target_node_id != node_id() ensures - sends a copy of item to the node with the given id. 
!*/ unsigned long outstanding_messages; unsigned long num_waiting_nodes; unsigned long num_terminated_nodes; dlib::uint64 current_epoch; impl1::thread_safe_message_queue msg_buffer; impl1::map_id_to_con& _cons; const unsigned long _node_id; array > threads; // ----------------------------------- template < typename funct_type > friend void bsp_connect ( const std::vector& hosts, funct_type funct ); template < typename funct_type, typename ARG1 > friend void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1 ); template < typename funct_type, typename ARG1, typename ARG2 > friend void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2 ); template < typename funct_type, typename ARG1, typename ARG2, typename ARG3 > friend void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); template < typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > friend void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); // ----------------------------------- template < typename port_notify_function_type, typename funct_type > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct ); template < typename port_notify_function_type, typename funct_type, typename ARG1 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1 ); template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2 ); template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, 
port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); // ----------------------------------- }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename funct_type > void bsp_connect ( const std::vector& hosts, funct_type funct ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1 > void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1 ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj,arg1); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2 > void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2 ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj,arg1,arg2); obj.close_all_connections_gracefully(); } // 
// ----------------------------------------------------------------------------------------

    // Controller entry point (this process becomes node 0), identical to the other
    // bsp_connect() overloads except that arg1..arg3 are passed on to funct.
    // If funct throws, close_all_connections_gracefully() is skipped and ~bsp_context()
    // just shuts the connections down.
    template <
        typename funct_type,
        typename ARG1,
        typename ARG2,
        typename ARG3
        >
    void bsp_connect (
        const std::vector& hosts,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2,
        ARG3 arg3
    )
    {
        impl1::map_id_to_con cons;
        const unsigned long node_id = 0;
        connect_all(cons, hosts, node_id);
        send_out_connection_orders(cons, hosts);
        bsp_context obj(node_id, cons);
        funct(obj,arg1,arg2,arg3);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------

    // As above, with four extra arguments passed on to funct.
    template <
        typename funct_type,
        typename ARG1,
        typename ARG2,
        typename ARG3,
        typename ARG4
        >
    void bsp_connect (
        const std::vector& hosts,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2,
        ARG3 arg3,
        ARG4 arg4
    )
    {
        impl1::map_id_to_con cons;
        const unsigned long node_id = 0;
        connect_all(cons, hosts, node_id);
        send_out_connection_orders(cons, hosts);
        bsp_context obj(node_id, cons);
        funct(obj,arg1,arg2,arg3,arg4);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Non-controller entry points for a fixed port.  Each bsp_listen() overload asserts
    // that listening_port is non-zero.  It then forwards to the matching
    // bsp_listen_dynamic_port() overload with impl1::null_notify, which ignores the port
    // notification.
    template <
        typename funct_type
        >
    void bsp_listen (
        unsigned short listening_port,
        funct_type funct
    )
    {
        // make sure requires clause is not broken
        DLIB_CASSERT(listening_port != 0,
            "\t void bsp_listen()"
            << "\n\t Invalid arguments were given to this function."
            );

        bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename funct_type,
        typename ARG1
        >
    void bsp_listen (
        unsigned short listening_port,
        funct_type funct,
        ARG1 arg1
    )
    {
        // make sure requires clause is not broken
        DLIB_CASSERT(listening_port != 0,
            "\t void bsp_listen()"
            << "\n\t Invalid arguments were given to this function."
            );

        bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename funct_type,
        typename ARG1,
        typename ARG2
        >
    void bsp_listen (
        unsigned short listening_port,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2
    )
    {
        // make sure requires clause is not broken
        DLIB_CASSERT(listening_port != 0,
            "\t void bsp_listen()"
            << "\n\t Invalid arguments were given to this function."
            );

        bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename funct_type,
        typename ARG1,
        typename ARG2,
        typename ARG3
        >
    void bsp_listen (
        unsigned short listening_port,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2,
        ARG3 arg3
    )
    {
        // make sure requires clause is not broken
        DLIB_CASSERT(listening_port != 0,
            "\t void bsp_listen()"
            << "\n\t Invalid arguments were given to this function."
            );

        bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename funct_type,
        typename ARG1,
        typename ARG2,
        typename ARG3,
        typename ARG4
        >
    void bsp_listen (
        unsigned short listening_port,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2,
        ARG3 arg3,
        ARG4 arg4
    )
    {
        // make sure requires clause is not broken
        DLIB_CASSERT(listening_port != 0,
            "\t void bsp_listen()"
            << "\n\t Invalid arguments were given to this function."
            );

        bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4);
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Non-controller entry points.  Each bsp_listen_dynamic_port() overload runs in
    // three steps:
    //   1. listen_and_connect_all() listens on listening_port and passes the actual
    //      listening port to port_notify_function.  It then waits for the controller's
    //      orders, sets node_id (it is an output of that call), and connects to the
    //      other peers.
    //   2. funct(obj, args...) runs with the bsp_context.
    //   3. The graceful shutdown follows.
    // If funct throws, the graceful shutdown is skipped and ~bsp_context() just shuts the
    // connections down.
    template <
        typename port_notify_function_type,
        typename funct_type
        >
    void bsp_listen_dynamic_port (
        unsigned short listening_port,
        port_notify_function_type port_notify_function,
        funct_type funct
    )
    {
        impl1::map_id_to_con cons;
        unsigned long node_id;
        listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
        bsp_context obj(node_id, cons);
        funct(obj);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename port_notify_function_type,
        typename funct_type,
        typename ARG1
        >
    void bsp_listen_dynamic_port (
        unsigned short listening_port,
        port_notify_function_type port_notify_function,
        funct_type funct,
        ARG1 arg1
    )
    {
        impl1::map_id_to_con cons;
        unsigned long node_id;
        listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
        bsp_context obj(node_id, cons);
        funct(obj,arg1);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename port_notify_function_type,
        typename funct_type,
        typename ARG1,
        typename ARG2
        >
    void bsp_listen_dynamic_port (
        unsigned short listening_port,
        port_notify_function_type port_notify_function,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2
    )
    {
        impl1::map_id_to_con cons;
        unsigned long node_id;
        listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
        bsp_context obj(node_id, cons);
        funct(obj,arg1,arg2);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename port_notify_function_type,
        typename funct_type,
        typename ARG1,
        typename ARG2,
        typename ARG3
        >
    void bsp_listen_dynamic_port (
        unsigned short listening_port,
        port_notify_function_type port_notify_function,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2,
        ARG3 arg3
    )
    {
        impl1::map_id_to_con cons;
        unsigned long node_id;
        listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
        bsp_context obj(node_id, cons);
        funct(obj,arg1,arg2,arg3);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename port_notify_function_type,
        typename funct_type,
        typename ARG1,
        typename ARG2,
        typename ARG3,
        typename ARG4
        >
    void bsp_listen_dynamic_port (
        unsigned short listening_port,
        port_notify_function_type port_notify_function,
        funct_type funct,
        ARG1 arg1,
        ARG2 arg2,
        ARG3 arg3,
        ARG4 arg4
    )
    {
        impl1::map_id_to_con cons;
        unsigned long node_id;
        listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
        bsp_context obj(node_id, cons);
        funct(obj,arg1,arg2,arg3,arg4);
        obj.close_all_connections_gracefully();
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

}

#ifdef NO_MAKEFILE
#include "bsp.cpp"
#endif

#endif // DLIB_BsP_Hh_

================================================
FILE: dlib/bsp/bsp_abstract.h
================================================
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
// The #undef immediately followed by #ifdef of the same macro means the code below is
// never compiled: this header is documentation only.
#undef DLIB_BsP_ABSTRACT_Hh_
#ifdef DLIB_BsP_ABSTRACT_Hh_

#include "../noncopyable.h"
#include "../sockets/sockets_extensions_abstract.h"
// NOTE(review): the header name of this #include is missing in this copy of the file.
#include

namespace dlib
{

// ----------------------------------------------------------------------------------------

    class bsp_context : noncopyable
    {
        /*!
WHAT THIS OBJECT REPRESENTS This is a tool used to implement algorithms using the Bulk Synchronous Parallel (BSP) computing model. A BSP algorithm is composed of a number of processing nodes, each executing in parallel. The general flow of execution in each processing node is the following: 1. Do work locally on some data. 2. Send some messages to other nodes. 3. Receive messages from other nodes. 4. Go to step 1 or terminate if complete. To do this, each processing node needs an API used to send and receive messages. This API is implemented by the bsp_connect object which provides these services to a BSP node. Note that BSP processing nodes are spawned using the bsp_connect() and bsp_listen() routines defined at the bottom of this file. For example, to start a BSP algorithm consisting of N processing nodes, you would make N-1 calls to bsp_listen() and one call to bsp_connect(). The call to bsp_connect() then initiates the computation on all nodes. Finally, note that there is no explicit barrier synchronization function you call at the end of step 3. Instead, you can simply call a method such as try_receive() until it returns false. That is, the bsp_context's receive methods incorporate a barrier synchronization that happens once all the BSP nodes are blocked on receive calls and there are no more messages in flight. THREAD SAFETY This object is not thread-safe. In particular, you should only ever have one thread that works with an instance of this object. This means that, for example, you should not spawn sub-threads from within a BSP processing node and have them invoke methods on this object. Instead, you should only invoke this object's methods from within the BSP processing node's main thread (i.e. the thread that executes the user supplied function funct()). !*/ public: template void send( const T& item, unsigned long target_node_id ); /*! 
requires - item is serializable - target_node_id < number_of_nodes() - target_node_id != node_id() ensures - sends a copy of item to the node with the given id. throws - dlib::socket_error: This exception is thrown if there is an error which prevents us from delivering the message to the given node. One way this might happen is if the target node has already terminated its execution or has lost network connectivity. !*/ template void broadcast ( const T& item ); /*! requires - item is serializable ensures - sends a copy of item to all other processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents us from delivering a message to one of the other nodes. This might happen, for example, if one of the nodes has terminated its execution or has lost network connectivity. !*/ unsigned long node_id ( ) const; /*! ensures - Returns the id of the current processing node. That is, returns a number N such that: - N < number_of_nodes() - N == the node id of the processing node that called node_id(). This is a number that uniquely identifies the processing node. !*/ unsigned long number_of_nodes ( ) const; /*! ensures - returns the number of processing nodes participating in the BSP computation. !*/ template bool try_receive ( T& item ); /*! requires - item is serializable ensures - if (this function returns true) then - #item == the next message which was sent to the calling processing node. - else - The following must have been true for this function to return false: - All other nodes were blocked on calls to receive(), try_receive(), or have terminated. - There were not any messages in flight between any nodes. - That is, if all the nodes had continued to block on receive methods then they all would have blocked forever. Therefore, this function only returns false once there are no more messages to process by any node and there is no possibility of more being generated until control is returned to the callers of receive methods.
- When one BSP node's receive method returns because of the above conditions then all of them will also return. That is, it is NOT the case that just a subset of BSP nodes unblock. Moreover, they all unblock at the same time. throws - dlib::socket_error: This exception is thrown if some error occurs which prevents us from communicating with other processing nodes. - dlib::serialization_error or any exception thrown by the global deserialize(T) routine: This is thrown if there is a problem in deserialize(). This might happen if the message sent doesn't match the type T expected by try_receive(). !*/ template void receive ( T& item ); /*! requires - item is serializable ensures - #item == the next message which was sent to the calling processing node. - This function is just a wrapper around try_receive() that throws an exception if a message is not received (i.e. if try_receive() returns false). throws - dlib::socket_error: This exception is thrown if some error occurs which prevents us from communicating with other processing nodes or if there was not a message to receive. - dlib::serialization_error or any exception thrown by the global deserialize(T) routine: This is thrown if there is a problem in deserialize(). This might happen if the message sent doesn't match the type T expected by receive(). !*/ template bool try_receive ( T& item, unsigned long& sending_node_id ); /*! requires - item is serializable ensures - if (this function returns true) then - #item == the next message which was sent to the calling processing node. - #sending_node_id == the node id of the node that sent this message. - #sending_node_id < number_of_nodes() - else - The following must have been true for this function to return false: - All other nodes were blocked on calls to receive(), try_receive(), or have terminated. - There were not any messages in flight between any nodes. 
- That is, if all the nodes had continued to block on receive methods then they all would have blocked forever. Therefore, this function only returns false once there are no more messages to process by any node and there is no possibility of more being generated until control is returned to the callers of receive methods. - When one BSP node's receive method returns because of the above conditions then all of them will also return. That is, it is NOT the case that just a subset of BSP nodes unblock. Moreover, they all unblock at the same time. throws - dlib::socket_error: This exception is thrown if some error occurs which prevents us from communicating with other processing nodes. - dlib::serialization_error or any exception thrown by the global deserialize(T) routine: This is thrown if there is a problem in deserialize(). This might happen if the message sent doesn't match the type T expected by try_receive(). !*/ template void receive ( T& item, unsigned long& sending_node_id ); /*! requires - item is serializable ensures - #item == the next message which was sent to the calling processing node. - #sending_node_id == the node id of the node that sent this message. - #sending_node_id < number_of_nodes() - This function is just a wrapper around try_receive() that throws an exception if a message is not received (i.e. if try_receive() returns false). throws - dlib::socket_error: This exception is thrown if some error occurs which prevents us from communicating with other processing nodes or if there was not a message to receive. - dlib::serialization_error or any exception thrown by the global deserialize(T) routine: This is thrown if there is a problem in deserialize(). This might happen if the message sent doesn't match the type T expected by receive(). !*/ void receive ( ); /*! ensures - Waits for the following to all be true: - All other nodes were blocked on calls to receive(), try_receive(), or have terminated. 
- There are not any messages in flight between any nodes. - That is, if all the nodes had continued to block on receive methods then they all would have blocked forever. Therefore, this function only returns once there are no more messages to process by any node and there is no possibility of more being generated until control is returned to the callers of receive methods. - When one BSP node's receive method returns because of the above conditions then all of them will also return. That is, it is NOT the case that just a subset of BSP nodes unblock. Moreover, they all unblock at the same time. throws - dlib::socket_error: This exception is thrown if some error occurs which prevents us from communicating with other processing nodes or if a message is received before this function would otherwise return. !*/ }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename funct_type > void bsp_connect ( const std::vector& hosts, funct_type funct ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT) must be a valid expression (i.e. funct must be a function or function object) ensures - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. - The processing node with a node ID of 0 will run locally on the machine calling bsp_connect(). In particular, this node will execute funct(CONTEXT), which is expected to carry out this node's portion of the BSP computation. - The other processing nodes are executed on the hosts indicated by the input argument. In particular, this function interprets hosts as a list of addresses identifying machines running the bsp_listen() or bsp_listen_dynamic_port() routines. - This call to bsp_connect() blocks until the BSP computation has completed on all processing nodes.
throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_connect(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1 > void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1) must be a valid expression (i.e. funct must be a function or function object) ensures - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. - The processing node with a node ID of 0 will run locally on the machine calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1), which is expected to carry out this node's portion of the BSP computation. - The other processing nodes are executed on the hosts indicated by the input argument. In particular, this function interprets hosts as a list of addresses identifying machines running the bsp_listen() or bsp_listen_dynamic_port() routines. - This call to bsp_connect() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_connect(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2 > void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2) must be a valid expression (i.e. funct must be a function or function object) ensures - This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
- The processing node with a node ID of 0 will run locally on the machine calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2), which is expected to carry out this node's portion of the BSP computation. - The other processing nodes are executed on the hosts indicated by the input argument. In particular, this function interprets hosts as a list of addresses identifying machines running the bsp_listen() or bsp_listen_dynamic_port() routines. - This call to bsp_connect() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_connect(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3 > void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression (i.e. funct must be a function or function object) ensures - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. - The processing node with a node ID of 0 will run locally on the machine calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3), which is expected to carry out this node's portion of the BSP computation. - The other processing nodes are executed on the hosts indicated by the input argument. In particular, this function interprets hosts as a list of addresses identifying machines running the bsp_listen() or bsp_listen_dynamic_port() routines. - This call to bsp_connect() blocks until the BSP computation has completed on all processing nodes.
throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_connect(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > void bsp_connect ( const std::vector& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression (i.e. funct must be a function or function object) ensures - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. - The processing node with a node ID of 0 will run locally on the machine calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3,arg4), which is expected to carry out this node's portion of the BSP computation. - The other processing nodes are executed on the hosts indicated by the input argument. In particular, this function interprets hosts as a list of addresses identifying machines running the bsp_listen() or bsp_listen_dynamic_port() routines. - This call to bsp_connect() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_connect(). !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename funct_type > void bsp_listen ( unsigned short listening_port, funct_type funct ); /*! requires - listening_port != 0 - let CONTEXT be an instance of a bsp_context object.
Then: - funct(CONTEXT) must be a valid expression (i.e. funct must be a function or function object) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - This function will listen on TCP port listening_port for a connection from bsp_connect(). Once the connection is established, it will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1 ); /*! requires - listening_port != 0 - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1) must be a valid expression (i.e. funct must be a function or function object) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - This function will listen on TCP port listening_port for a connection from bsp_connect(). Once the connection is established, it will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen() blocks until the BSP computation has completed on all processing nodes.
throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1, ARG2 arg2 ); /*! requires - listening_port != 0 - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2) must be a valid expression (i.e. funct must be a function or function object) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1,arg2) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - This function will listen on TCP port listening_port for a connection from bsp_connect(). Once the connection is established, it will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); /*! requires - listening_port != 0 - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression (i.e.
funct must be a function or function object) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - This function will listen on TCP port listening_port for a connection from bsp_connect(). Once the connection is established, it will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen(). !*/ // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); /*! requires - listening_port != 0 - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression (i.e. funct must be a function or function object) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - This function will listen on TCP port listening_port for a connection from bsp_connect(). Once the connection is established, it will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted.
- This call to bsp_listen() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen(). !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT) must be a valid expression (i.e. funct must be a function or function object) - port_notify_function((unsigned short) 1234) must be a valid expression (i.e. port_notify_function() must be a function or function object taking an unsigned short) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - if (listening_port != 0) then - This function will listen on TCP port listening_port for a connection from bsp_connect(). - else - An available TCP port number is automatically selected and this function will listen on it for a connection from bsp_connect(). - Once a listening port is opened, port_notify_function() is called with the port number used. This provides a mechanism to find out what listening port has been used if it is automatically selected. It also allows you to find out when the routine has begun listening for an incoming connection from bsp_connect(). - Once a connection is established, we will close the listening port so it is free for use by other applications.
The connection and BSP computation will continue uninterrupted. - This call to bsp_listen_dynamic_port() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen_dynamic_port(). !*/ // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1) must be a valid expression (i.e. funct must be a function or function object) - port_notify_function((unsigned short) 1234) must be a valid expression (i.e. port_notify_function() must be a function or function object taking an unsigned short) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - if (listening_port != 0) then - This function will listen on TCP port listening_port for a connection from bsp_connect(). - else - An available TCP port number is automatically selected and this function will listen on it for a connection from bsp_connect(). - Once a listening port is opened, port_notify_function() is called with the port number used. This provides a mechanism to find out what listening port has been used if it is automatically selected. It also allows you to find out when the routine has begun listening for an incoming connection from bsp_connect(). - Once a connection is established, we will close the listening port so it is free for use by other applications.
The connection and BSP computation will continue uninterrupted. - This call to bsp_listen_dynamic_port() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen_dynamic_port(). !*/ // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2) must be a valid expression (i.e. funct must be a function or function object) - port_notify_function((unsigned short) 1234) must be a valid expression (i.e. port_notify_function() must be a function or function object taking an unsigned short) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1,arg2) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - if (listening_port != 0) then - This function will listen on TCP port listening_port for a connection from bsp_connect(). - else - An available TCP port number is automatically selected and this function will listen on it for a connection from bsp_connect(). - Once a listening port is opened, port_notify_function() is called with the port number used. This provides a mechanism to find out what listening port has been used if it is automatically selected. It also allows you to find out when the routine has begun listening for an incoming connection from bsp_connect().
- Once a connection is established, we will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen_dynamic_port() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen_dynamic_port(). !*/ // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression (i.e. funct must be a function or function object) - port_notify_function((unsigned short) 1234) must be a valid expression (i.e. port_notify_function() must be a function or function object taking an unsigned short) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - if (listening_port != 0) then - This function will listen on TCP port listening_port for a connection from bsp_connect(). - else - An available TCP port number is automatically selected and this function will listen on it for a connection from bsp_connect(). - Once a listening port is opened, port_notify_function() is called with the port number used. This provides a mechanism to find out what listening port has been used if it is automatically selected.
It also allows you to find out when the routine has begun listening for an incoming connection from bsp_connect(). - Once a connection is established, we will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen_dynamic_port() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_listen_dynamic_port(). !*/ // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); /*! requires - let CONTEXT be an instance of a bsp_context object. Then: - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression (i.e. funct must be a function or function object) - port_notify_function((unsigned short) 1234) must be a valid expression (i.e. port_notify_function() must be a function or function object taking an unsigned short) ensures - This function listens for a connection from the bsp_connect() routine. Once this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be executed and it will then be able to participate in the BSP computation as one of the processing nodes. - if (listening_port != 0) then - This function will listen on TCP port listening_port for a connection from bsp_connect(). - else - An available TCP port number is automatically selected and this function will listen on it for a connection from bsp_connect(). - Once a listening port is opened, port_notify_function() is called with the port number used.
This provides a mechanism to find out what listening port has been used if it is automatically selected. It also allows you to find out when the routine has begun listening for an incoming connection from bsp_connect(). - Once a connection is established, we will close the listening port so it is free for use by other applications. The connection and BSP computation will continue uninterrupted. - This call to bsp_listen_dynamic_port() blocks until the BSP computation has completed on all processing nodes. throws - dlib::socket_error This exception is thrown if there is an error which prevents the BSP job from executing. - Any exception thrown by funct() will be propagated out of this call to bsp_connect(). !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- } #endif // DLIB_BsP_ABSTRACT_Hh_ ================================================ FILE: dlib/bsp.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BSPh_ #define DLIB_BSPh_ #include "bsp/bsp.h" #endif // DLIB_BSPh_ ================================================ FILE: dlib/byte_orderer/byte_orderer_kernel_1.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BYTE_ORDEREr_KERNEL_1_ #define DLIB_BYTE_ORDEREr_KERNEL_1_ #include "byte_orderer_kernel_abstract.h" #include "../algs.h" #include "../assert.h" namespace dlib { /* byte_orderer converts values between the host's native byte order and big endian (network) or little endian byte order. Its documented contract is in byte_orderer_kernel_abstract.h. */ class byte_orderer { /*!
INITIAL VALUE - if (this machine is little endian) then - little_endian == true - else - little_endian == false CONVENTION - host_is_big_endian() == !little_endian - host_is_little_endian() == little_endian - if (this machine is little endian) then - little_endian == true - else - little_endian == false !*/ public: // this is here for backwards compatibility with older versions of dlib. typedef byte_orderer kernel_1a; byte_orderer ( ) { // This will probably never be false but if it is then it means chars are not 8bits // on this system. Which is a problem for this object. COMPILE_TIME_ASSERT(sizeof(short) >= 2); /* NOTE(review): sizeof(short) >= 2 always holds when chars are 8 bits, but it also holds for some other char widths (e.g. 9 or 12 bit chars), so this check alone does not prove chars are 8 bits -- confirm intent. */ /* Runtime endianness probe: store 1 in an unsigned long and inspect its lowest-addressed byte, which is 1 on a little endian host. */ unsigned long temp = 1; unsigned char* ptr = reinterpret_cast(&temp); if (*ptr == 1) little_endian = true; else little_endian = false; } virtual ~byte_orderer ( ){} bool host_is_big_endian ( ) const { return !little_endian; } bool host_is_little_endian ( ) const { return little_endian; } /* Network byte order is big endian, so the two network functions below do the same thing as host_to_big()/big_to_host(). */ template < typename T > inline void host_to_network ( T& item ) const { if (little_endian) flip(item); } template < typename T > inline void network_to_host ( T& item ) const { if (little_endian) flip(item); } template < typename T > void host_to_big ( T& item ) const { if (little_endian) flip(item); } template < typename T > void big_to_host ( T& item ) const { if (little_endian) flip(item); } template < typename T > void host_to_little ( T& item ) const { if (!little_endian) flip(item); } template < typename T > void little_to_host ( T& item ) const { if (!little_endian) flip(item); } private: template < typename T, size_t size > inline void flip ( T (&array)[size] ) const /*! 
ensures - flips the bytes in every element of this array !*/ { for (size_t i = 0; i < size; ++i) { flip(array[i]); } } template < typename T > inline void flip ( T& item ) const /*! ensures - reverses the byte ordering in item !*/ { DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); T value; // If you are getting this as an error then you are probably using // this object wrong.
If you think you aren't then send me (Davis) an // email and I'll either set you straight or change/remove this check so // your stuff works :) COMPILE_TIME_ASSERT(sizeof(T) <= sizeof(long double)); // If you are getting a compile error on this line then it means T is // a pointer type. It doesn't make any sense to byte swap pointers // since they have no meaning outside the context of their own process. // So you probably just forgot to dereference that pointer before passing // it to this function :) COMPILE_TIME_ASSERT(is_pointer_type::value == false); const size_t size = sizeof(T); unsigned char* const ptr = reinterpret_cast(&item); unsigned char* const ptr_temp = reinterpret_cast(&value); for (size_t i = 0; i < size; ++i) ptr_temp[size-i-1] = ptr[i]; item = value; } bool little_endian; }; // make flip not do anything at all for chars template <> inline void byte_orderer::flip ( char& ) const {} template <> inline void byte_orderer::flip ( unsigned char& ) const {} template <> inline void byte_orderer::flip ( signed char& ) const {} } #endif // DLIB_BYTE_ORDEREr_KERNEL_1_ ================================================ FILE: dlib/byte_orderer/byte_orderer_kernel_abstract.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BYTE_ORDEREr_ABSTRACT_ #ifdef DLIB_BYTE_ORDEREr_ABSTRACT_ #include "../algs.h" namespace dlib { class byte_orderer { /*! INITIAL VALUE This object has no state. WHAT THIS OBJECT REPRESENTS This object simply provides a mechanism to convert data from a host machine's own byte ordering to big or little endian and to also do the reverse. It also provides a pair of functions to convert to/from network byte order where network byte order is big endian byte order.
This pair of functions does the exact same thing as the host_to_big() and big_to_host() functions and is provided simply so that client code can use the most self documenting name appropriate. Also note that this object is capable of correctly flipping the contents of arrays when the arrays are declared on the stack. e.g. You can say things like: int array[10]; bo.host_to_network(array); !*/ public: byte_orderer ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc !*/ virtual ~byte_orderer ( ); /*! ensures - any resources associated with *this have been released !*/ bool host_is_big_endian ( ) const; /*! ensures - if (the host computer is a big endian machine) then - returns true - else - returns false !*/ bool host_is_little_endian ( ) const; /*! ensures - if (the host computer is a little endian machine) then - returns true - else - returns false !*/ template < typename T > void host_to_network ( T& item ) const; /*! ensures - #item == the value of item converted from host byte order to network byte order. !*/ template < typename T > void network_to_host ( T& item ) const; /*! ensures - #item == the value of item converted from network byte order to host byte order. !*/ template < typename T > void host_to_big ( T& item ) const; /*! ensures - #item == the value of item converted from host byte order to big endian byte order. !*/ template < typename T > void big_to_host ( T& item ) const; /*! ensures - #item == the value of item converted from big endian byte order to host byte order. !*/ template < typename T > void host_to_little ( T& item ) const; /*! ensures - #item == the value of item converted from host byte order to little endian byte order. !*/ template < typename T > void little_to_host ( T& item ) const; /*! ensures - #item == the value of item converted from little endian byte order to host byte order.
!*/ }; } #endif // DLIB_BYTE_ORDEREr_ABSTRACT_ ================================================ FILE: dlib/byte_orderer.h ================================================ // Copyright (C) 2006 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BYTE_ORDEREr_ #define DLIB_BYTE_ORDEREr_ #include "byte_orderer/byte_orderer_kernel_1.h" #endif // DLIB_BYTE_ORDEREr_ ================================================ FILE: dlib/cassert ================================================ #include "dlib_include_path_tutorial.txt" ================================================ FILE: dlib/clustering/bottom_up_cluster.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_ #define DLIB_BOTTOM_uP_CLUSTER_Hh_ #include #include #include "bottom_up_cluster_abstract.h" #include "../algs.h" #include "../matrix.h" #include "../disjoint_subsets.h" #include "../graph_utils.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace buc_impl { /* Merges row/column src into dest using complete linkage: every distance to dest becomes the larger of the old dest and src distances. Both dists(dest,r) and dists(r,dest) are written so the matrix stays symmetric. */ inline void merge_sets ( matrix& dists, unsigned long dest, unsigned long src ) { for (long r = 0; r < dists.nr(); ++r) dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src)); } /* Comparator for std::priority_queue. The queue keeps the element that compares greatest on top, so using > here puts the sample_pair with the SMALLEST distance on top. 
*/ struct compare_dist { bool operator() ( const sample_pair& a, const sample_pair& b ) const { return a.distance() > b.distance(); } }; } // ---------------------------------------------------------------------------------------- template < typename EXP > unsigned long bottom_up_cluster ( const matrix_exp& dists_, std::vector& labels, unsigned long min_num_clusters, double max_dist = std::numeric_limits::infinity() ) { matrix dists = matrix_cast(dists_); // make sure requires clause is not broken DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, "\t unsigned long
bottom_up_cluster()" << "\n\t Invalid inputs were given to this function." << "\n\t dists.nr(): " << dists.nr() << "\n\t dists.nc(): " << dists.nc() << "\n\t min_num_clusters: " << min_num_clusters ); using namespace buc_impl; labels.resize(dists.nr()); disjoint_subsets sets; sets.set_size(dists.nr()); if (labels.size() == 0) return 0; // push all the edges in the graph into a priority queue so the best edges to merge // come first. std::priority_queue, compare_dist> que; for (long r = 0; r < dists.nr(); ++r) for (long c = r+1; c < dists.nc(); ++c) que.push(sample_pair(r,c,dists(r,c))); // Now start merging nodes. for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter) { // find the next best thing to merge. double best_dist = que.top().distance(); unsigned long a = sets.find_set(que.top().index1()); unsigned long b = sets.find_set(que.top().index2()); que.pop(); // we have been merging and modifying the distances, so make sure this distance // is still valid and these guys haven't been merged already. while(a == b || best_dist < dists(a,b)) { // Haven't merged it yet, so put it back in with updated distance for // reconsideration later. if (a != b) que.push(sample_pair(a, b, dists(a, b))); best_dist = que.top().distance(); a = sets.find_set(que.top().index1()); b = sets.find_set(que.top().index2()); que.pop(); } // now merge these sets if the best distance is small enough if (best_dist > max_dist) break; unsigned long news = sets.merge_sets(a,b); unsigned long olds = (news==a)?b:a; merge_sets(dists, news, olds); } // figure out which cluster each element is in. Also make sure the labels are // contiguous.
std::map relabel; for (unsigned long r = 0; r < labels.size(); ++r) { unsigned long l = sets.find_set(r); // relabel to make contiguous if (relabel.count(l) == 0) { unsigned long next = relabel.size(); relabel[l] = next; } labels[r] = relabel[l]; } return relabel.size(); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- struct snl_range { snl_range() = default; snl_range(double val) : lower(val), upper(val) {} snl_range(double l, double u) : lower(l), upper(u) { DLIB_ASSERT(lower <= upper)} double lower = 0; double upper = 0; double width() const { return upper-lower; } bool operator<(const snl_range& item) const { return lower < item.lower; } }; /* Returns the smallest snl_range that covers both a and b. */ inline snl_range merge(const snl_range& a, const snl_range& b) { return snl_range(std::min(a.lower, b.lower), std::max(a.upper, b.upper)); } /* Returns the gap between two ranges: positive when they are disjoint, 0 when they touch, and negative when they overlap. */ inline double distance (const snl_range& a, const snl_range& b) { return std::max(a.lower,b.lower) - std::min(a.upper,b.upper); } inline std::ostream& operator<< (std::ostream& out, const snl_range& item ) { out << "["< segment_number_line ( const std::vector& x, const double max_range_width ) { DLIB_CASSERT(max_range_width >= 0); // create initial ranges, one for each value in x. So initially, all the ranges have // width of 0. std::vector ranges; for (auto v : x) ranges.push_back(v); std::sort(ranges.begin(), ranges.end()); std::vector greedy_final_ranges; if (ranges.size() == 0) return greedy_final_ranges; // We will try two different clustering strategies. One that does a simple greedy left // to right sweep and another that does a bottom up agglomerative clustering. This // first loop runs the greedy left to right sweep. 
Then at the end of this routine we // will return the results that produced the tightest clustering.
greedy_final_ranges.push_back(ranges[0]); for (size_t i = 1; i < ranges.size(); ++i) { auto m = merge(greedy_final_ranges.back(), ranges[i]); if (m.width() <= max_range_width) greedy_final_ranges.back() = m; else greedy_final_ranges.push_back(ranges[i]); } // Here we do the bottom up clustering. So compute the edges connecting our ranges. // We will simply say there are edges between ranges if and only if they are // immediately adjacent on the number line. std::vector edges; for (size_t i = 1; i < ranges.size(); ++i) edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i]))); std::sort(edges.begin(), edges.end(), order_by_distance); disjoint_subsets sets; sets.set_size(ranges.size()); // Now start merging nodes. for (auto edge : edges) { // find the next best thing to merge. unsigned long a = sets.find_set(edge.index1()); unsigned long b = sets.find_set(edge.index2()); // merge it if it doesn't result in an interval that's too big. auto m = merge(ranges[a], ranges[b]); if (m.width() <= max_range_width) { unsigned long news = sets.merge_sets(a,b); ranges[news] = m; } } // Now create a list of the final ranges. We will do this by keeping track of which // range we already added to final_ranges. std::vector final_ranges; std::vector already_output(ranges.size(), false); for (unsigned long i = 0; i < sets.size(); ++i) { auto s = sets.find_set(i); if (!already_output[s]) { final_ranges.push_back(ranges[s]); already_output[s] = true; } } // only use the greedy clusters if they found a clustering with fewer clusters. // Otherwise, the bottom up clustering probably produced a more sensible clustering.
if (final_ranges.size() <= greedy_final_ranges.size()) return final_ranges; else return greedy_final_ranges; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BOTTOM_uP_CLUSTER_Hh_ ================================================ FILE: dlib/clustering/bottom_up_cluster_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ #ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ #include "../matrix.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename EXP > unsigned long bottom_up_cluster ( const matrix_exp& dists, std::vector& labels, unsigned long min_num_clusters, double max_dist = std::numeric_limits::infinity() ); /*! requires - dists.nr() == dists.nc() - min_num_clusters > 0 - dists == trans(dists) (i.e. dists should be symmetric) ensures - Runs a bottom up agglomerative clustering algorithm. - Interprets dists as a matrix that gives the distances between dists.nr() items. In particular, we take dists(i,j) to be the distance between the ith and jth element of some set. This function clusters the elements of this set into at least min_num_clusters (or dists.nr() if there aren't enough elements). Additionally, within each cluster, the maximum pairwise distance between any two cluster elements is <= max_dist. - returns the number of clusters found. - #labels.size() == dists.nr() - for all valid i: - #labels[i] == the cluster ID of the node with index i (i.e. the node corresponding to the distances dists(i,*)). - 0 <= #labels[i] < the number of clusters found (i.e.
cluster IDs are assigned contiguously and start at 0) !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- struct snl_range { /*! WHAT THIS OBJECT REPRESENTS This object represents an interval on the real number line. It is used to store the outputs of the segment_number_line() routine defined below. !*/ snl_range( ); /*! ensures - #lower == 0 - #upper == 0 !*/ snl_range( double val ); /*! ensures - #lower == val - #upper == val !*/ snl_range( double l, double u ); /*! requires - l <= u ensures - #lower == l - #upper == u !*/ double lower; double upper; double width( ) const { return upper-lower; } /*! ensures - returns the width of this interval on the number line. !*/ bool operator<(const snl_range& item) const { return lower < item.lower; } /*! ensures - provides a total ordering of snl_range objects assuming they are non-overlapping. !*/ }; std::ostream& operator<< (std::ostream& out, const snl_range& item ); /*! ensures - prints item to out in the form [lower,upper]. !*/ // ---------------------------------------------------------------------------------------- std::vector segment_number_line ( const std::vector& x, const double max_range_width ); /*! requires - max_range_width >= 0 ensures - Finds a clustering of the values in x and returns the ranges that define the clustering. This routine uses a combination of bottom up clustering and a simple greedy scan to try and find the most compact set of ranges that contain all the values in x. - This routine has approximately linear runtime. - Every value in x will be contained inside one of the returned snl_range objects; - All returned snl_range objects will have a width() <= max_range_width and will also be non-overlapping.
!*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ ================================================ FILE: dlib/clustering/chinese_whispers.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CHINESE_WHISPErS_Hh_ #define DLIB_CHINESE_WHISPErS_Hh_ #include "chinese_whispers_abstract.h" #include #include "../rand.h" #include "../graph_utils/edge_list_graphs.h" namespace dlib { // ---------------------------------------------------------------------------------------- inline unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations, dlib::rand& rnd ) { // make sure requires clause is not broken DLIB_ASSERT(is_ordered_by_index(edges), "\t unsigned long chinese_whispers()" << "\n\t Invalid inputs were given to this function" ); labels.clear(); if (edges.size() == 0) return 0; std::vector > neighbors; find_neighbor_ranges(edges, neighbors); // Initialize the labels, each node gets a different label. labels.resize(neighbors.size()); for (unsigned long i = 0; i < labels.size(); ++i) labels[i] = i; for (unsigned long iter = 0; iter < neighbors.size()*num_iterations; ++iter) { // Pick a random node. const unsigned long idx = rnd.get_random_64bit_number()%neighbors.size(); // Count how many times each label happens amongst our neighbors.
std::map labels_to_counts; const unsigned long end = neighbors[idx].second; for (unsigned long i = neighbors[idx].first; i != end; ++i) { /* Weighted count: each neighbor adds its edge's distance() (the edge weight) to the score of that neighbor's label. */ labels_to_counts[labels[edges[i].index2()]] += edges[i].distance(); } // find the most common label std::map::iterator i; double best_score = -std::numeric_limits::infinity(); unsigned long best_label = labels[idx]; for (i = labels_to_counts.begin(); i != labels_to_counts.end(); ++i) { if (i->second > best_score) { best_score = i->second; best_label = i->first; } } labels[idx] = best_label; } // Remap the labels into a contiguous range. First we find the // mapping. std::map label_remap; for (unsigned long i = 0; i < labels.size(); ++i) { const unsigned long next_id = label_remap.size(); if (label_remap.count(labels[i]) == 0) label_remap[labels[i]] = next_id; } // now apply the mapping to all the labels. for (unsigned long i = 0; i < labels.size(); ++i) { labels[i] = label_remap[labels[i]]; } return label_remap.size(); } // ---------------------------------------------------------------------------------------- /* Overload for unordered edges: converts them to ordered pairs, sorts them by index, and then runs the routine above. */ inline unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations, dlib::rand& rnd ) { std::vector oedges; convert_unordered_to_ordered(edges, oedges); std::sort(oedges.begin(), oedges.end(), &order_by_index); return chinese_whispers(oedges, labels, num_iterations, rnd); } // ---------------------------------------------------------------------------------------- /* Convenience overload: forwards to chinese_whispers() with a freshly default-constructed dlib::rand. 
*/ inline unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations = 100 ) { dlib::rand rnd; return chinese_whispers(edges, labels, num_iterations, rnd); } // ---------------------------------------------------------------------------------------- /* Convenience overload: forwards to chinese_whispers() with a freshly default-constructed dlib::rand. */ inline unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations = 100 ) { dlib::rand rnd; return chinese_whispers(edges, labels, num_iterations, rnd); }
// ---------------------------------------------------------------------------------------- } #endif // DLIB_CHINESE_WHISPErS_Hh_ ================================================ FILE: dlib/clustering/chinese_whispers_abstract.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ #ifdef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ #include #include "../rand.h" #include "../graph_utils/ordered_sample_pair_abstract.h" #include "../graph_utils/sample_pair_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations, dlib::rand& rnd ); /*! requires - is_ordered_by_index(edges) == true ensures - This function implements the graph clustering algorithm described in the paper: Chinese Whispers - an Efficient Graph Clustering Algorithm and its Application to Natural Language Processing Problems by Chris Biemann. - Interprets edges as a directed graph. That is, it contains the edges on the said graph and the ordered_sample_pair::distance() values define the edge weights (larger values indicating a stronger edge connection between the nodes). If an edge has a distance() value of infinity then it is considered a "must link" edge. - returns the number of clusters found. - #labels.size() == max_index_plus_one(edges) - for all valid i: - #labels[i] == the cluster ID of the node with index i in the graph. - 0 <= #labels[i] < the number of clusters found (i.e. cluster IDs are assigned contiguously and start at 0) - Duplicate edges are interpreted as if there had been just one edge with a distance value equal to the sum of all the duplicate edges' distance values. - The algorithm performs exactly num_iterations passes over the graph before terminating. 
(Each pass consists of N label updates of randomly chosen nodes, where N == #labels.size(), so an individual node may be updated more or fewer than num_iterations times.)
!*/ // ---------------------------------------------------------------------------------------- unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations, dlib::rand& rnd ); /*! ensures - This function is identical to the above chinese_whispers() routine except that it operates on a vector of sample_pair objects instead of ordered_sample_pairs. Therefore, this is simply a convenience routine. In particular, it is implemented by transforming the given edges into ordered_sample_pairs and then calling the chinese_whispers() routine defined above. !*/ // ---------------------------------------------------------------------------------------- unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations = 100 ); /*! requires - is_ordered_by_index(edges) == true ensures - performs: return chinese_whispers(edges, labels, num_iterations, rnd) where rnd is a default initialized dlib::rand object. !*/ // ---------------------------------------------------------------------------------------- unsigned long chinese_whispers ( const std::vector& edges, std::vector& labels, const unsigned long num_iterations = 100 ); /*! ensures - performs: return chinese_whispers(edges, labels, num_iterations, rnd) where rnd is a default initialized dlib::rand object. !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ ================================================ FILE: dlib/clustering/modularity_clustering.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MODULARITY_ClUSTERING__H__ #define DLIB_MODULARITY_ClUSTERING__H__ #include "modularity_clustering_abstract.h" #include "../sparse_vector.h" #include "../graph_utils/edge_list_graphs.h" #include "../matrix.h" #include "../rand.h" namespace dlib { // ----------------------------------------------------------------------------------------- namespace impl { inline double newman_cluster_split ( dlib::rand& rnd, const std::vector& edges, const matrix& node_degrees, // k from the Newman paper const matrix& Bdiag, // diag(B) from the Newman paper const double& edge_sum, // m from the Newman paper matrix& labels, const double eps, const unsigned long max_iterations ) /*! requires - node_degrees.size() == max_index_plus_one(edges) - Bdiag.size() == max_index_plus_one(edges) - edges must be sorted according to order_by_index() ensures - This routine splits a graph into two subgraphs using the Newman clustering method. - returns the modularity obtained when the graph is split according to the contents of #labels. - #labels.size() == node_degrees.size() - for all valid i: #labels(i) == -1 or +1 - if (this function returns 0) then - all the labels are equal, i.e. the graph is not split. !*/ { // Scale epsilon so that it is relative to the expected value of an element of a // unit vector of length node_degrees.size(). const double power_iter_eps = eps * std::sqrt(1.0/node_degrees.size()); // Make a random unit vector and put in labels. labels.set_size(node_degrees.size()); for (long i = 0; i < labels.size(); ++i) labels(i) = rnd.get_random_gaussian(); labels /= length(labels); matrix Bv, Bv_unit; // Do the power iteration for a while. 
double eig = -1; double offset = 0; while (eig < 0) { // any number larger than power_iter_eps double iteration_change = power_iter_eps*2+1; for (unsigned long i = 0; i < max_iterations && iteration_change > power_iter_eps; ++i) { sparse_matrix_vector_multiply(edges, labels, Bv); Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; if (offset != 0) { Bv -= offset*labels; } const double len = length(Bv); if (len != 0) { Bv_unit = Bv/len; iteration_change = max(abs(labels-Bv_unit)); labels.swap(Bv_unit); } else { // Had a bad time, pick another random vector and try it with the // power iteration. for (long i = 0; i < labels.size(); ++i) labels(i) = rnd.get_random_gaussian(); } } eig = dot(Bv,labels); // we will repeat this loop if the largest eigenvalue is negative offset = eig; } for (long i = 0; i < labels.size(); ++i) { if (labels(i) > 0) labels(i) = 1; else labels(i) = -1; } // compute B*labels, store result in Bv. sparse_matrix_vector_multiply(edges, labels, Bv); Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; // Do some label refinement. In this step we swap labels if it // improves the modularity score. 
bool flipped_label = true;
while(flipped_label)
{
    flipped_label = false;
    // idx walks through edges in lockstep with i.  This relies on the edges being
    // grouped by index1() in increasing order (newman_cluster() asserts
    // is_ordered_by_index(edges)), so the edges leaving node i are contiguous.
    unsigned long idx = 0;
    for (long i = 0; i < labels.size(); ++i)
    {
        // val is how much labels(i) changes if its sign is flipped.
        const double val = -2*labels(i);
        // Change in labels'*B*labels caused by that flip, given Bv == B*labels and a
        // symmetric B: 2*val*Bv(i) + val*val*B(i,i), where val*val == 4.
        const double increase = 4*Bdiag(i) + 2*val*Bv(i);
        // if there is an increase in modularity for swapping this label
        if (increase > 0)
        {
            labels(i) *= -1;
            // Keep Bv == B*labels current.  First the edge-weight part of B: every
            // edge (i,j) adds val*distance() to Bv(j) ...
            while (idx < edges.size() && edges[idx].index1() == (unsigned long)i)
            {
                const long j = edges[idx].index2();
                Bv(j) += val*edges[idx].distance();
                ++idx;
            }
            // ... then the -k*k'/(2m) part of B.
            Bv -= (val*node_degrees(i)/(2*edge_sum))*node_degrees;
            flipped_label = true;
        }
        else
        {
            // No flip: just skip node i's edges so idx stays aligned with i.
            while (idx < edges.size() && edges[idx].index1() == (unsigned long)i)
            {
                ++idx;
            }
        }
    }
}

// Score of the final two-way split: labels'*B*labels / (4m).
const double modularity = dot(Bv, labels)/(4*edge_sum);
return modularity;
}

// -------------------------------------------------------------------------------------

inline unsigned long newman_cluster_helper (
    dlib::rand& rnd,
    const std::vector& edges,
    const matrix& node_degrees,  // k from the Newman paper
    const matrix& Bdiag,         // diag(B) from the Newman paper
    const double& edge_sum,      // m from the Newman paper
    std::vector& labels,
    double modularity_threshold,
    const double eps,
    const unsigned long max_iterations
)
/*!
    ensures
        - returns the number of clusters the data was split into
        - Recursively bisects the graph.  newman_cluster_split() proposes a two-way
          split.  The split is kept only if its score exceeds modularity_threshold
          and neither side is empty; both sides are then clustered recursively.
          Otherwise every node is put in cluster 0.
        - #labels.size() == node_degrees.size().  The cluster IDs in #labels are
          contiguous and start at 0; IDs coming from the right side are offset by
          the number of clusters found on the left side.
        - eps and max_iterations are forwarded unchanged to newman_cluster_split().
!*/
{
    // l receives the proposed split: l(i) > 0 puts node i on the "left" side, and
    // any other value puts it on the "right" side.
    matrix l;
    const double modularity = newman_cluster_split(rnd,edges,node_degrees,Bdiag,edge_sum,l,eps,max_iterations);

    // We need to collapse the node index values down to contiguous values.  So
    // we use the following two vectors to contain the mappings from input index
    // values to their corresponding index values in each split.
    std::vector left_idx_map(node_degrees.size());
    std::vector right_idx_map(node_degrees.size());

    // figure out how many nodes went into each side of the split.
    unsigned long num_left_split = 0;
    unsigned long num_right_split = 0;
    for (long i = 0; i < l.size(); ++i)
    {
        if (l(i) > 0)
        {
            left_idx_map[i] = num_left_split;
            ++num_left_split;
        }
        else
        {
            right_idx_map[i] = num_right_split;
            ++num_right_split;
        }
    }

    // do a recursive split if it will improve the modularity.
    if (modularity > modularity_threshold && num_left_split > 0 && num_right_split > 0)
    {
        // split the node_degrees and Bdiag matrices into left and right split parts
        matrix left_node_degrees(num_left_split);
        matrix right_node_degrees(num_right_split);
        matrix left_Bdiag(num_left_split);
        matrix right_Bdiag(num_right_split);
        for (long i = 0; i < l.size(); ++i)
        {
            if (l(i) > 0)
            {
                left_node_degrees(left_idx_map[i]) = node_degrees(i);
                left_Bdiag(left_idx_map[i]) = Bdiag(i);
            }
            else
            {
                right_node_degrees(right_idx_map[i]) = node_degrees(i);
                right_Bdiag(right_idx_map[i]) = Bdiag(i);
            }
        }

        // put the edges from one side of the split into split_edges.
        // While doing so, modularity_threshold is rebuilt as the score that
        // newman_cluster_split() would report for this side if it were left
        // whole (all labels equal):
        //   (sum of the side's edge weights - (sum of its degrees)^2/(2m)) / (4m).
        // The recursive call keeps a further split of this side only if it beats
        // that value.
        std::vector split_edges;
        modularity_threshold = 0;
        for (unsigned long k = 0; k < edges.size(); ++k)
        {
            const unsigned long i = edges[k].index1();
            const unsigned long j = edges[k].index2();
            const double d = edges[k].distance();
            if (l(i) > 0 && l(j) > 0)
            {
                split_edges.push_back(ordered_sample_pair(left_idx_map[i], left_idx_map[j], d));
                modularity_threshold += d;
            }
        }
        // sum(x*sum(x)) == sum(x)^2
        modularity_threshold -= sum(left_node_degrees*sum(left_node_degrees))/(2*edge_sum);
        modularity_threshold /= 4*edge_sum;

        unsigned long num_left_clusters;
        std::vector left_labels;
        num_left_clusters = newman_cluster_helper(rnd,split_edges,left_node_degrees,left_Bdiag,
                                                  edge_sum,left_labels,modularity_threshold,
                                                  eps, max_iterations);

        // now load the other side into split_edges and cluster it as well
        split_edges.clear();
        modularity_threshold = 0;
        for (unsigned long k = 0; k < edges.size(); ++k)
        {
            const unsigned long i = edges[k].index1();
            const unsigned long j = edges[k].index2();
            const double d = edges[k].distance();
            if (l(i) < 0 && l(j) < 0)
            {
                split_edges.push_back(ordered_sample_pair(right_idx_map[i], right_idx_map[j], d));
                modularity_threshold += d;
            }
        }
        modularity_threshold -= sum(right_node_degrees*sum(right_node_degrees))/(2*edge_sum);
        modularity_threshold /= 4*edge_sum;

        unsigned long num_right_clusters;
        std::vector right_labels;
        num_right_clusters = newman_cluster_helper(rnd,split_edges,right_node_degrees,right_Bdiag,
                                                   edge_sum,right_labels,modularity_threshold,
                                                   eps, max_iterations);

        // Now merge the labels from the two splits.
        labels.resize(node_degrees.size());
        for (unsigned long i = 0; i < labels.size(); ++i)
        {
            // if this node was in the left split
            if (l(i) > 0)
            {
                labels[i] = left_labels[left_idx_map[i]];
            }
            else // if this node was in the right split
            {
                // Offset right-side IDs so they follow the left-side IDs.
                labels[i] = right_labels[right_idx_map[i]] + num_left_clusters;
            }
        }

        return num_left_clusters + num_right_clusters;
    }
    else
    {
        // No worthwhile split: this whole (sub)graph becomes a single cluster.
        labels.assign(node_degrees.size(),0);
        return 1;
    }
}

}

// ----------------------------------------------------------------------------------------

inline unsigned long newman_cluster (
    const std::vector& edges,
    std::vector& labels,
    const double eps = 1e-4,
    const unsigned long max_iterations = 2000
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT(is_ordered_by_index(edges),
        "\t unsigned long newman_cluster()"
        << "\n\t Invalid inputs were given to this function"
    );

    labels.clear();
    // An empty graph has no nodes and therefore no clusters.
    if (edges.size() == 0)
        return 0;

    const unsigned long num_nodes = max_index_plus_one(edges);

    // compute the node_degrees vector, edge_sum value, and diag(B).
    matrix node_degrees(num_nodes);
    matrix Bdiag(num_nodes);
    Bdiag = 0;
    double edge_sum = 0;
    node_degrees = 0;
    for (unsigned long i = 0; i < edges.size(); ++i)
    {
        node_degrees(edges[i].index1()) += edges[i].distance();
        edge_sum += edges[i].distance();
        // Self-loops contribute their weight to the diagonal of the adjacency part of B.
        if (edges[i].index1() == edges[i].index2())
            Bdiag(edges[i].index1()) += edges[i].distance();
    }
    // Halve the total.  Presumably each undirected edge is listed once in each
    // direction, so the raw sum counts every weight twice.
    edge_sum /= 2;
    // NOTE(review): if every edge distance is 0 (the requires clause allows this),
    // edge_sum is 0 here and the division below divides by zero.  Confirm that the
    // resulting NaN handling downstream is intended.
    Bdiag -= squared(node_degrees)/(2*edge_sum);
    dlib::rand rnd;
    // The top-level threshold is 0: the first split is kept only if its score is > 0.
    return impl::newman_cluster_helper(rnd,edges,node_degrees,Bdiag,edge_sum,labels,0,eps,max_iterations);
}

// ----------------------------------------------------------------------------------------

inline unsigned long newman_cluster (
    const std::vector& edges,
    std::vector& labels,
    const double eps = 1e-4,
    const unsigned long max_iterations = 2000
)
{
    // Convert the unordered edges into ordered_sample_pairs and sort them so the
    // is_ordered_by_index() requirement of the overload above holds, then defer to it.
    std::vector oedges;
    convert_unordered_to_ordered(edges, oedges);
    std::sort(oedges.begin(), oedges.end(), &order_by_index);

    return newman_cluster(oedges, labels, eps, max_iterations);
}

// ----------------------------------------------------------------------------------------

namespace impl
{
    inline std::vector remap_labels (
        const std::vector& labels,
        unsigned long& num_labels
    )
    /*!
        ensures
            - This function takes labels and produces a mapping which maps elements of
              labels into the most compact range in [0, max] as possible.  In
              particular, there won't be any unused integers in the mapped range.
            - #num_labels == the number of distinct values in labels.
            - returns a vector V such that:
                - V.size() == labels.size()
                - max(mat(V))+1 == num_labels.
                - for all valid i,j:
                    - if (labels[i] == labels[j]) then
                        - V[i] == V[j]
                    - else
                        - V[i] != V[j]
    !*/
    {
        // Assign compact IDs in order of first appearance: the first distinct label
        // seen maps to 0, the next to 1, and so on.
        std::map temp;
        for (unsigned long i = 0; i < labels.size(); ++i)
        {
            if (temp.count(labels[i]) == 0)
            {
                const unsigned long next = temp.size();
                temp[labels[i]] = next;
            }
        }

        num_labels = temp.size();

        std::vector result(labels.size());
        for (unsigned long i = 0; i < labels.size(); ++i)
        {
            result[i] = temp[labels[i]];
        }
        return result;
    }
}

// ----------------------------------------------------------------------------------------

inline double modularity (
    const std::vector& edges,
    const std::vector& labels
)
{
    const unsigned long num_nodes = max_index_plus_one(edges);
    // make sure requires clause is not broken
    DLIB_ASSERT(labels.size() == num_nodes,
        "\t double modularity()"
        << "\n\t Invalid inputs were given to this function"
    );

    // labels_ holds compact cluster IDs in [0, num_labels), used to index cluster_sums.
    unsigned long num_labels;
    const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels);

    // cluster_sums[c] == total degree of the nodes in cluster c; k[n] == degree of node n.
    std::vector cluster_sums(num_labels,0);
    std::vector k(num_nodes,0);

    double Q = 0;
    double m = 0;
    for (unsigned long i = 0; i < edges.size(); ++i)
    {
        const unsigned long n1 = edges[i].index1();
        const unsigned long n2 = edges[i].index2();
        // An undirected edge adds to the degree of both endpoints; a self-loop adds once.
        k[n1] += edges[i].distance();
        if (n1 != n2)
            k[n2] += edges[i].distance();

        // m is the total edge weight, with self-loops counted at half weight.
        if (n1 != n2)
            m += edges[i].distance();
        else
            m += edges[i].distance()/2;

        // Intra-community weight: a normal edge is counted for both (n1,n2) and
        // (n2,n1); a self-loop is counted once.
        if (labels_[n1] == labels_[n2])
        {
            if (n1 != n2)
                Q += 2*edges[i].distance();
            else
                Q += edges[i].distance();
        }
    }

    // A graph with no edge weight has zero modularity by convention here.
    if (m == 0)
        return 0;

    for (unsigned long i = 0; i < labels_.size(); ++i)
    {
        cluster_sums[labels_[i]] += k[i];
    }

    // Subtract the expected intra-community weight: sum over same-cluster pairs of
    // k_i*k_j/(2m).
    for (unsigned long i = 0; i < labels_.size(); ++i)
    {
        Q -= k[i]*cluster_sums[labels_[i]]/(2*m);
    }

    return 1.0/(2*m)*Q;
}

// ----------------------------------------------------------------------------------------

inline double modularity (
    const std::vector& edges,
    const std::vector& labels
)
{
    const unsigned long num_nodes = max_index_plus_one(edges);
    // make sure requires clause is not broken
    DLIB_ASSERT(labels.size() == num_nodes,
        "\t double modularity()"
        << "\n\t Invalid inputs were given to this function"
    );

    // labels_ holds compact cluster IDs in [0, num_labels), used to index cluster_sums.
    unsigned long num_labels;
    const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels);

    std::vector cluster_sums(num_labels,0);
    std::vector k(num_nodes,0);

    double Q = 0;
    double m = 0;
    for (unsigned long i = 0; i < edges.size(); ++i)
    {
        const unsigned long n1 = edges[i].index1();
        const unsigned long n2 = edges[i].index2();
        // Each ordered edge only adds to the degree of its source node.  With the
        // symmetric input the documentation asks for, the reverse edge covers n2.
        k[n1] += edges[i].distance();
        m += edges[i].distance();
        if (labels_[n1] == labels_[n2])
        {
            Q += edges[i].distance();
        }
    }

    // A graph with no edge weight has zero modularity by convention here.
    if (m == 0)
        return 0;

    for (unsigned long i = 0; i < labels_.size(); ++i)
    {
        cluster_sums[labels_[i]] += k[i];
    }

    // Here m is the full sum over both edge directions, i.e. it already plays the
    // role of 2m in the undirected overload above.
    for (unsigned long i = 0; i < labels_.size(); ++i)
    {
        Q -= k[i]*cluster_sums[labels_[i]]/m;
    }

    return 1.0/m*Q;
}

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_MODULARITY_ClUSTERING__H__
================================================
FILE: dlib/clustering/modularity_clustering_abstract.h
================================================
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
#ifdef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_

#include 
#include "../graph_utils/ordered_sample_pair_abstract.h"
#include "../graph_utils/sample_pair_abstract.h"

namespace dlib
{

// -----------------------------------------------------------------------------------------

    double modularity (
        const std::vector& edges,
        const std::vector& labels
    );
    /*!
        requires
            - labels.size() == max_index_plus_one(edges)
            - for all valid i:
                - 0 <= edges[i].distance() < std::numeric_limits::infinity()
        ensures
            - Interprets edges as an undirected graph.  That is, it contains the edges
              on the said graph and the sample_pair::distance() values define the edge
              weights (larger values indicating a stronger edge connection between the
              nodes).
- This function returns the modularity value obtained when the given input
              graph is broken into subgraphs according to the contents of labels.  In
              particular, we say that two nodes with indices i and j are in the same
              subgraph or community if and only if labels[i] == labels[j].
            - Duplicate edges are interpreted as if there had been just one edge with a
              distance value equal to the sum of all the duplicate edges' distance
              values.
            - See the paper Modularity and community structure in networks by M. E. J. Newman
              for a detailed definition.
    !*/

// ----------------------------------------------------------------------------------------

    double modularity (
        const std::vector& edges,
        const std::vector& labels
    );
    /*!
        requires
            - labels.size() == max_index_plus_one(edges)
            - for all valid i:
                - 0 <= edges[i].distance() < std::numeric_limits::infinity()
        ensures
            - Interprets edges as a directed graph.  That is, it contains the edges on
              the said graph and the ordered_sample_pair::distance() values define the
              edge weights (larger values indicating a stronger edge connection between
              the nodes).  Note that, generally, modularity is only really defined for
              undirected graphs.  Therefore, the "directed graph" given to this function
              should have symmetric edges between all nodes.  The reason this function
              is provided at all is because sometimes a vector of ordered_sample_pair
              objects is a useful representation of an undirected graph.
            - This function returns the modularity value obtained when the given input
              graph is broken into subgraphs according to the contents of labels.  In
              particular, we say that two nodes with indices i and j are in the same
              subgraph or community if and only if labels[i] == labels[j].
            - Duplicate edges are interpreted as if there had been just one edge with a
              distance value equal to the sum of all the duplicate edges' distance
              values.
            - See the paper Modularity and community structure in networks by M. E. J. Newman
              for a detailed definition.
    !*/

// ----------------------------------------------------------------------------------------

    unsigned long newman_cluster (
        const std::vector& edges,
        std::vector& labels,
        const double eps = 1e-4,
        const unsigned long max_iterations = 2000
    );
    /*!
        requires
            - is_ordered_by_index(edges) == true
            - for all valid i:
                - 0 <= edges[i].distance() < std::numeric_limits::infinity()
        ensures
            - This function performs the clustering algorithm described in the paper
              Modularity and community structure in networks by M. E. J. Newman.
            - This function interprets edges as a graph and attempts to find the
              labeling that maximizes modularity(edges, #labels).
            - returns the number of clusters found.
            - #labels.size() == max_index_plus_one(edges)
              (so if edges is empty, #labels is empty and 0 is returned)
            - for all valid i:
                - #labels[i] == the cluster ID of the node with index i in the graph.
                - 0 <= #labels[i] < the number of clusters found
                  (i.e. cluster IDs are assigned contiguously and start at 0)
            - The main computation of the algorithm involves finding an eigenvector of
              a certain matrix.  To do this, we use the power iteration.  In particular,
              each time we try to find an eigenvector we let the power iteration run
              for at most max_iterations iterations or until it reaches an accuracy of
              eps, whichever comes first.
    !*/

// ----------------------------------------------------------------------------------------

    unsigned long newman_cluster (
        const std::vector& edges,
        std::vector& labels,
        const double eps = 1e-4,
        const unsigned long max_iterations = 2000
    );
    /*!
        requires
            - for all valid i:
                - 0 <= edges[i].distance() < std::numeric_limits::infinity()
        ensures
            - This function is identical to the above newman_cluster() routine except
              that it operates on a vector of sample_pair objects instead of
              ordered_sample_pairs.  Therefore, this is simply a convenience routine.  In
              particular, it is implemented by transforming the given edges into
              ordered_sample_pairs and then calling the newman_cluster() routine defined
              above.
    !*/

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
================================================
FILE: dlib/clustering/spectral_cluster.h
================================================
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SPECTRAL_CLUSTEr_H_
#define DLIB_SPECTRAL_CLUSTEr_H_

#include "spectral_cluster_abstract.h"
#include 
#include "../matrix.h"
#include "../svm/kkmeans.h"

namespace dlib
{
    template <
        typename kernel_type,
        typename vector_type
        >
    std::vector spectral_cluster (
        const kernel_type& k,
        const vector_type& samples,
        const unsigned long num_clusters
    )
    {
        DLIB_CASSERT(num_clusters > 0,
            "\t std::vector spectral_cluster(k,samples,num_clusters)"
            << "\n\t num_clusters can't be 0."
            );

        if (num_clusters == 1)
        {
            // nothing to do, just assign everything to the 0 cluster.
            return std::vector(samples.size(), 0);
        }

        // compute the similarity matrix.  k() is evaluated once per unordered pair
        // and mirrored, so K is symmetric.
        matrix K(samples.size(), samples.size());
        for (long r = 0; r < K.nr(); ++r)
            for (long c = r+1; c < K.nc(); ++c)
                K(r,c) = K(c,r) = (double)k(samples[r], samples[c]);
        // Self-similarity is excluded: the diagonal is forced to 0.
        for (long r = 0; r < K.nr(); ++r)
            K(r,r) = 0;

        // D(r) is first the row sum of K ...
        matrix D(K.nr());
        for (long r = 0; r < K.nr(); ++r)
            D(r) = sum(rowm(K,r));
        // ... and then its reciprocal square root, so that K becomes the normalized
        // affinity matrix diag(D)*K*diag(D).
        D = sqrt(reciprocal(D));
        K = diagm(D)*K*diagm(D);
        matrix u,w,v;
        // Use the normal SVD routine unless the matrix is really big, then use the fast
        // approximate version.
        if (K.nr() < 1000)
            svd3(K,u,w,v);
        else
            svd_fast(K,u,w,v, num_clusters+100, 5);
        // Pick out the eigenvectors associated with the largest eigenvalues.
        rsort_columns(v,w);
        // NOTE(review): this assumes v has at least num_clusters columns, i.e.
        // num_clusters <= samples.size().  The requires clause does not state this.
        v = colm(v, range(0,num_clusters-1));
        // Now build the normalized spectral vectors, one for each input vector.
        // Each row of v becomes a sample, scaled to unit length when it is nonzero.
        std::vector > spec_samps, centers;
        for (long r = 0; r < v.nr(); ++r)
        {
            spec_samps.push_back(trans(rowm(v,r)));
            const double len = length(spec_samps.back());
            if (len != 0)
                spec_samps.back() /= len;
        }
        // Finally do the K-means clustering
        pick_initial_centers(num_clusters, centers, spec_samps);
        find_clusters_using_kmeans(spec_samps, centers);
        // And then compute the cluster assignments based on the output of K-means.
        std::vector assignments;
        for (unsigned long i = 0; i < spec_samps.size(); ++i)
            assignments.push_back(nearest_center(centers, spec_samps[i]));
        return assignments;
    }

}

#endif // DLIB_SPECTRAL_CLUSTEr_H_
================================================
FILE: dlib/clustering/spectral_cluster_abstract.h
================================================
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
#ifdef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_

#include 

namespace dlib
{
    template <
        typename kernel_type,
        typename vector_type
        >
    std::vector spectral_cluster (
        const kernel_type& k,
        const vector_type& samples,
        const unsigned long num_clusters
    );
    /*!
        requires
            - samples must be something with an interface compatible with std::vector.
            - The following expression must evaluate to a double or float:
              k(samples[i], samples[j])
            - num_clusters > 0
        ensures
            - Performs the spectral clustering algorithm described in the paper
              "On spectral clustering: Analysis and an algorithm" by Ng, Jordan, and
              Weiss, and returns the results.
            - This function clusters the input data samples into num_clusters clusters
              and returns a vector that indicates which cluster each sample falls into.
              In particular, we return an array A such that:
                - A.size() == samples.size()
                - A[i] == the cluster assignment of samples[i].
                - for all valid i: 0 <= A[i] < num_clusters
            - The "similarity" of samples[i] with samples[j] is given by
              k(samples[i],samples[j]).
This means that k() should output a number >= 0 and the number should
              be larger for samples that are more similar.
    !*/
}

#endif // DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_
================================================
FILE: dlib/clustering.h
================================================
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CLuSTERING_
#define DLIB_CLuSTERING_

#include "clustering/modularity_clustering.h"
#include "clustering/chinese_whispers.h"
#include "clustering/spectral_cluster.h"
#include "clustering/bottom_up_cluster.h"
#include "svm/kkmeans.h"

#endif // DLIB_CLuSTERING_
================================================
FILE: dlib/cmake
================================================
cmake_minimum_required(VERSION 3.8.0)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR} dlib_build)
================================================
FILE: dlib/cmake_utils/FindCUDNN.cmake
================================================
# Find the CUDNN libraries
#
# The following variables are optionally searched for defaults
#  CUDNN_ROOT: Base directory where CUDNN is found
#  CUDNN_INCLUDE_DIR: Directory where CUDNN header is searched for
#  CUDNN_LIBRARY: Location of the CUDNN library; it is passed to find_library()
#                 as an additional search path
#  CUDNN_STATIC: Are we looking for a static library? (default: no)
#
# The environment variables CUDNN_ROOT_DIR (deprecated, use CUDNN_ROOT instead),
# CUDNN_INCLUDE_DIR, CUDNN_LIBRARY and CUDNN_HOME are consulted as well.
#
# The following are set after configuration is done:
#  CUDNN_FOUND
#  CUDNN_INCLUDE_PATH
#  CUDNN_LIBRARY_PATH
#  CUDNN_VERSION (with CUDNN_VERSION_MAJOR/_MINOR/_PATCH), only when CUDNN_FOUND;
#                CUDNN_VERSION is "?" if the version could not be parsed.
#
include(FindPackageHandleStandardArgs)

set(CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuDNN")
# Test the environment variable itself.  The old form, `DEFINED $ENV{CUDNN_ROOT_DIR}`,
# expanded the variable's *value* and asked whether a CMake variable with that name
# was defined, so this deprecation warning effectively never fired.
if (DEFINED ENV{CUDNN_ROOT_DIR})
    message(WARNING "CUDNN_ROOT_DIR is deprecated. Please set CUDNN_ROOT instead.")
endif()
list(APPEND CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})

# Compatibility layer for CMake < 3.12: add CUDNN_ROOT to CMAKE_PREFIX_PATH so the
# find_* calls below search it.  From CMake 3.12 on (policy CMP0074),
# find_package(CUDNN) also honors CUDNN_ROOT on its own.
list(APPEND CMAKE_PREFIX_PATH ${CUDNN_ROOT})

set(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuDNN header files")

set(CUDA_VERSION "${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}")

find_path(CUDNN_INCLUDE_PATH cudnn.h
    HINTS ${CUDNN_INCLUDE_DIR} ENV CUDNN_INCLUDE_DIR ENV CUDNN_HOME
    PATHS /usr/local /usr/local/cuda "C:/Program Files/NVIDIA/CUDNN/*/include/${CUDA_VERSION}" "C:/Program Files/NVIDIA/CUDNN/*/include/*" ENV CPATH
    PATH_SUFFIXES cuda/include cuda include)

option(CUDNN_STATIC "Look for static CUDNN" OFF)
if (CUDNN_STATIC)
    set(CUDNN_LIBNAME "libcudnn_static.a")
else()
    set(CUDNN_LIBNAME "cudnn")
endif()

set(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY} CACHE PATH "Path to the cudnn library file (e.g., libcudnn.so)")
if (CUDNN_LIBRARY MATCHES ".*cudnn_static.a" AND NOT CUDNN_STATIC)
    message(WARNING "CUDNN_LIBRARY points to a static library (${CUDNN_LIBRARY}) but CUDNN_STATIC is OFF.")
endif()

find_library(CUDNN_LIBRARY_PATH ${CUDNN_LIBNAME}
    PATHS ${CUDNN_LIBRARY} /usr/local /usr/local/cuda "C:/Program Files/NVIDIA/CUDNN/*/lib/${CUDA_VERSION}/x64" "C:/Program Files/NVIDIA/CUDNN/*/lib/${CUDA_VERSION}" "C:/Program Files/NVIDIA/CUDNN/*/lib/*" ENV LD_LIBRARY_PATH
    PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)

find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_LIBRARY_PATH CUDNN_INCLUDE_PATH)

if(CUDNN_FOUND)
    # Get cuDNN version.  cudnn_version.h is read when it exists; otherwise the
    # version macros are taken from cudnn.h.
    if(EXISTS ${CUDNN_INCLUDE_PATH}/cudnn_version.h)
        file(READ ${CUDNN_INCLUDE_PATH}/cudnn_version.h CUDNN_HEADER_CONTENTS)
    else()
        file(READ ${CUDNN_INCLUDE_PATH}/cudnn.h CUDNN_HEADER_CONTENTS)
    endif()
    string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
        CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}")
    string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
        CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}")
    string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)"
        CUDNN_VERSION_MINOR "${CUDNN_HEADER_CONTENTS}")
    string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1"
        CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}")
    string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)"
        CUDNN_VERSION_PATCH "${CUDNN_HEADER_CONTENTS}")
    string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1"
        CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}")
    # Assemble cuDNN version
    if(NOT CUDNN_VERSION_MAJOR)
        set(CUDNN_VERSION "?")
    else()
        set(CUDNN_VERSION
            "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}")
    endif()
endif()

mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION)
================================================
FILE: dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake
================================================
# This script checks if your compiler and host processor can generate and then run programs with AVX instructions.

cmake_minimum_required(VERSION 3.10.0)

# Don't rerun this script if it's already been executed.
if (DEFINED AVX_IS_AVAILABLE_ON_HOST)
    return()
endif()

# Set to false unless we find out otherwise in the code below.
set(AVX_IS_AVAILABLE_ON_HOST 0)

try_compile(test_for_avx_worked ${PROJECT_BINARY_DIR}/avx_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_avx
    avx_test)

if(test_for_avx_worked)
    message (STATUS "AVX instructions can be executed by the host processor.")
    set(AVX_IS_AVAILABLE_ON_HOST 1)
endif()
================================================
FILE: dlib/cmake_utils/check_if_neon_available.cmake
================================================
# This script checks if __ARM_NEON__ is defined for your compiler

cmake_minimum_required(VERSION 3.10.0)

# Don't rerun this script if it's already been executed.
if (DEFINED ARM_NEON_IS_AVAILABLE)
    return()
endif()

# Set to false unless we find out otherwise in the code below.
set(ARM_NEON_IS_AVAILABLE 0)

# test if __ARM_NEON__ is defined
try_compile(test_for_neon_worked ${PROJECT_BINARY_DIR}/neon_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_neon
    neon_test)

if(test_for_neon_worked)
    message (STATUS "__ARM_NEON__ defined.")
    set(ARM_NEON_IS_AVAILABLE 1)
endif()
================================================
FILE: dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake
================================================
# This script checks if your compiler and host processor can generate and then run programs with SSE4 instructions.

cmake_minimum_required(VERSION 3.10.0)

# Don't rerun this script if it's already been executed.
if (DEFINED SSE4_IS_AVAILABLE_ON_HOST)
    return()
endif()

# Set to false unless we find out otherwise in the code below.
set(SSE4_IS_AVAILABLE_ON_HOST 0)

try_compile(test_for_sse4_worked ${PROJECT_BINARY_DIR}/sse4_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_sse4
    sse4_test)

if(test_for_sse4_worked)
    message (STATUS "SSE4 instructions can be executed by the host processor.")
    set(SSE4_IS_AVAILABLE_ON_HOST 1)
endif()
================================================
FILE: dlib/cmake_utils/dlib.pc.in
================================================
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@

Name: @PROJECT_NAME@
Description: Numerical and networking C++ library
Version: @VERSION@
Libs: -L${libdir} -ldlib @pkg_config_dlib_needed_libraries@
Cflags: -I${includedir} @pkg_config_dlib_needed_includes@
================================================
FILE: dlib/cmake_utils/dlibConfig.cmake.in
================================================
# ===================================================================================
#  The dlib CMake configuration file
#
#             ** File generated automatically, do not modify **
#
#  Usage from an external project:
#    In your CMakeLists.txt, add these lines:
#
#    find_package(dlib REQUIRED)
#    target_link_libraries(MY_TARGET_NAME dlib::dlib)
#
# ===================================================================================


# Our library dependencies (contains definitions for IMPORTED targets)
if(NOT TARGET dlib-shared AND NOT dlib_BINARY_DIR)
   # Compute paths
   get_filename_component(dlib_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
   include("${dlib_CMAKE_DIR}/dlib.cmake")

   # Check if Threads::Threads target is required and find it if necessary
   get_target_property(dlib_deps_threads_check dlib::dlib INTERFACE_LINK_LIBRARIES)
   list(FIND dlib_deps_threads_check "Threads::Threads" dlib_deps_threads_idx)
   if (${dlib_deps_threads_idx} GREATER -1)
      # Test for the imported target dlib actually links against.  The previous
      # `TARGET Threads` check named a target that FindThreads never creates, so it
      # was always true and find_package(Threads) ran even when Threads::Threads
      # already existed.
      if (NOT TARGET Threads::Threads)
         find_package(Threads REQUIRED)
      endif()
   endif()
   unset(dlib_deps_threads_idx)
   unset(dlib_deps_threads_check)
endif()

set(dlib_LIBRARIES dlib::dlib)
set(dlib_LIBS      dlib::dlib)
set(dlib_INCLUDE_DIRS "@CMAKE_INSTALL_FULL_INCLUDEDIR@" "@dlib_needed_includes@")

if (@DLIB_USE_CUDA@)
   find_package(CUDAToolkit)
endif()

mark_as_advanced(dlib_LIBRARIES)
mark_as_advanced(dlib_LIBS)
mark_as_advanced(dlib_INCLUDE_DIRS)

# Mark these variables above as deprecated.
function(__deprecated_var var access)
  if(access STREQUAL "READ_ACCESS")
      message(WARNING "The variable '${var}' is deprecated! Instead, simply use target_link_libraries(your_app dlib::dlib). See http://dlib.net/examples/CMakeLists.txt.html for an example.")
  endif()
endfunction()
variable_watch(dlib_LIBRARIES __deprecated_var)
variable_watch(dlib_LIBS __deprecated_var)
variable_watch(dlib_INCLUDE_DIRS __deprecated_var)
================================================
FILE: dlib/cmake_utils/find_blas.cmake
================================================
#
# This is a CMake makefile.  You can find the cmake utility and
# information about it at http://www.cmake.org
#
#
# This cmake file tries to find installed BLAS and LAPACK libraries.
# It looks for an installed copy of the Intel MKL library first and then
# attempts to find some other BLAS and LAPACK libraries if you don't have
# the Intel MKL.
# # blas_found - True if BLAS is available # lapack_found - True if LAPACK is available # found_intel_mkl - True if the Intel MKL library is available # found_intel_mkl_headers - True if Intel MKL headers are available # blas_libraries - link against these to use BLAS library # lapack_libraries - link against these to use LAPACK library # mkl_libraries - link against these to use the MKL library # mkl_include_dir - add to the include path to use the MKL library # openmp_libraries - Set to Intel's OpenMP library if and only if we # find the MKL. # setting this makes CMake allow normal looking if else statements SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true) SET(blas_found 0) SET(lapack_found 0) SET(found_intel_mkl 0) SET(found_intel_mkl_headers 0) SET(lapack_with_underscore 0) SET(lapack_without_underscore 0) message(STATUS "Searching for BLAS and LAPACK") INCLUDE(CheckFunctionExists) if (UNIX OR MINGW) message(STATUS "Searching for BLAS and LAPACK") if (BUILDING_MATLAB_MEX_FILE) # # This commented out stuff would link directly to MATLAB's built in # BLAS and LAPACK. But it's better to not link to anything and do a #find_library(MATLAB_BLAS_LIBRARY mwblas PATHS ${MATLAB_LIB_FOLDERS} ) #find_library(MATLAB_LAPACK_LIBRARY mwlapack PATHS ${MATLAB_LIB_FOLDERS} ) #if (MATLAB_BLAS_LIBRARY AND MATLAB_LAPACK_LIBRARY) # add_subdirectory(external/cblas) # set(blas_libraries ${MATLAB_BLAS_LIBRARY} cblas ) # set(lapack_libraries ${MATLAB_LAPACK_LIBRARY} ) # set(blas_found 1) # set(lapack_found 1) # message(STATUS "Found MATLAB's BLAS and LAPACK libraries") #endif() # We need cblas since MATLAB doesn't provide cblas symbols. add_subdirectory(external/cblas) set(blas_libraries cblas ) set(blas_found 1) set(lapack_found 1) message(STATUS "Will link with MATLAB's BLAS and LAPACK at runtime (hopefully!)") ## Don't try to link to anything other than MATLAB's own internal blas ## and lapack libraries because doing so generally upsets MATLAB. So ## we just end here no matter what. 
return() endif() # First, search for libraries via pkg-config, which is the cleanest path find_package(PkgConfig) pkg_check_modules(BLAS_REFERENCE cblas) pkg_check_modules(LAPACK_REFERENCE lapack) # Make sure the cblas found by pkgconfig actually has cblas symbols. SET(CMAKE_REQUIRED_LIBRARIES "${BLAS_REFERENCE_LDFLAGS}") CHECK_FUNCTION_EXISTS(cblas_ddot PKGCFG_HAVE_CBLAS) if (BLAS_REFERENCE_FOUND AND LAPACK_REFERENCE_FOUND AND PKGCFG_HAVE_CBLAS) set(blas_libraries "${BLAS_REFERENCE_LDFLAGS}") set(lapack_libraries "${LAPACK_REFERENCE_LDFLAGS}") set(blas_found 1) set(lapack_found 1) set(REQUIRES_LIBS "${REQUIRES_LIBS} cblas lapack") message(STATUS "Found BLAS and LAPACK via pkg-config") return() endif() include(CheckTypeSize) check_type_size( "void*" SIZE_OF_VOID_PTR) if (SIZE_OF_VOID_PTR EQUAL 8) set( mkl_search_path /opt/intel/oneapi/mkl/latest/lib/intel64 /opt/intel/mkl/*/lib/em64t /opt/intel/mkl/lib/intel64 /opt/intel/lib/intel64 /opt/intel/mkl/lib /opt/intel/tbb/*/lib/em64t/gcc4.7 /opt/intel/tbb/lib/intel64/gcc4.7 /opt/intel/tbb/lib/gcc4.7 ) find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) mark_as_advanced(mkl_intel) else() set( mkl_search_path /opt/intel/oneapi/mkl/latest/lib/ia32 /opt/intel/mkl/*/lib/32 /opt/intel/mkl/lib/ia32 /opt/intel/lib/ia32 /opt/intel/tbb/*/lib/32/gcc4.7 /opt/intel/tbb/lib/ia32/gcc4.7 ) find_library(mkl_intel mkl_intel ${mkl_search_path}) mark_as_advanced(mkl_intel) endif() include(CheckLibraryExists) # Get mkl_include_dir set(mkl_include_search_path /opt/intel/oneapi/mkl/latest/include /opt/intel/mkl/include /opt/intel/include ) find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) mark_as_advanced(mkl_include_dir) if(NOT DLIB_USE_MKL_SEQUENTIAL AND NOT DLIB_USE_MKL_WITH_TBB) # Search for the needed libraries from the MKL. We will try to link against the mkl_rt # file first since this way avoids linking bugs in some cases. 
find_library(mkl_rt mkl_rt ${mkl_search_path}) find_library(openmp_libraries iomp5 ${mkl_search_path}) mark_as_advanced(mkl_rt openmp_libraries) # if we found the MKL if (mkl_rt) set(mkl_libraries ${mkl_rt} ) set(blas_libraries ${mkl_rt} ) set(lapack_libraries ${mkl_rt} ) set(blas_found 1) set(lapack_found 1) set(found_intel_mkl 1) message(STATUS "Found Intel MKL BLAS/LAPACK library") endif() endif() if (NOT found_intel_mkl) # Search for the needed libraries from the MKL. This time try looking for a different # set of MKL files and try to link against those. find_library(mkl_core mkl_core ${mkl_search_path}) set(mkl_libs ${mkl_intel} ${mkl_core}) mark_as_advanced(mkl_libs mkl_intel mkl_core) if (DLIB_USE_MKL_WITH_TBB) find_library(mkl_tbb_thread mkl_tbb_thread ${mkl_search_path}) find_library(mkl_tbb tbb ${mkl_search_path}) mark_as_advanced(mkl_tbb_thread mkl_tbb) list(APPEND mkl_libs ${mkl_tbb_thread} ${mkl_tbb}) elseif (DLIB_USE_MKL_SEQUENTIAL) find_library(mkl_sequential mkl_sequential ${mkl_search_path}) mark_as_advanced(mkl_sequential) list(APPEND mkl_libs ${mkl_sequential}) else() find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) find_library(mkl_iomp iomp5 ${mkl_search_path}) find_library(mkl_pthread pthread ${mkl_search_path}) mark_as_advanced(mkl_thread mkl_iomp mkl_pthread) list(APPEND mkl_libs ${mkl_thread} ${mkl_iomp} ${mkl_pthread}) endif() # If we found the MKL if (mkl_intel AND mkl_core AND ((mkl_tbb_thread AND mkl_tbb) OR (mkl_thread AND mkl_iomp AND mkl_pthread) OR mkl_sequential)) set(mkl_libraries ${mkl_libs}) set(blas_libraries ${mkl_libs}) set(lapack_libraries ${mkl_libs}) set(blas_found 1) set(lapack_found 1) set(found_intel_mkl 1) message(STATUS "Found Intel MKL BLAS/LAPACK library") endif() endif() if (found_intel_mkl AND mkl_include_dir) set(found_intel_mkl_headers 1) endif() # try to find some other LAPACK libraries if we didn't find the MKL set(extra_paths /usr/lib64 /usr/lib64/atlas-sse3 /usr/lib64/atlas-sse2 /usr/lib64/atlas 
/usr/lib /usr/lib/atlas-sse3 /usr/lib/atlas-sse2 /usr/lib/atlas /usr/lib/openblas-base /opt/OpenBLAS/lib $ENV{OPENBLAS_HOME}/lib ) if (NOT blas_found) find_library(cblas_lib NAMES openblasp openblas PATHS ${extra_paths}) if (cblas_lib) set(blas_libraries ${cblas_lib}) set(blas_found 1) message(STATUS "Found OpenBLAS library") set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) # If you compiled OpenBLAS with LAPACK in it then it should have the # sgetrf_single function in it. So if we find that function in # OpenBLAS then just use OpenBLAS's LAPACK. CHECK_FUNCTION_EXISTS(sgetrf_single OPENBLAS_HAS_LAPACK) if (OPENBLAS_HAS_LAPACK) message(STATUS "Using OpenBLAS's built in LAPACK") # set(lapack_libraries gfortran) set(lapack_found 1) endif() endif() mark_as_advanced( cblas_lib) endif() if (NOT lapack_found) find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths}) if (lapack_lib) set(lapack_libraries ${lapack_lib}) set(lapack_found 1) message(STATUS "Found LAPACK library") endif() mark_as_advanced( lapack_lib) endif() # try to find some other BLAS libraries if we didn't find the MKL if (NOT blas_found) find_library(atlas_lib atlas PATHS ${extra_paths}) find_library(cblas_lib cblas PATHS ${extra_paths}) if (atlas_lib AND cblas_lib) set(blas_libraries ${atlas_lib} ${cblas_lib}) set(blas_found 1) message(STATUS "Found ATLAS BLAS library") endif() mark_as_advanced( atlas_lib cblas_lib) endif() # CentOS 7 atlas if (NOT blas_found) find_library(tatlas_lib tatlas PATHS ${extra_paths}) find_library(satlas_lib satlas PATHS ${extra_paths}) if (tatlas_lib AND satlas_lib ) set(blas_libraries ${tatlas_lib} ${satlas_lib}) set(blas_found 1) message(STATUS "Found ATLAS BLAS library") endif() mark_as_advanced( tatlas_lib satlas_lib) endif() if (NOT blas_found) find_library(cblas_lib cblas PATHS ${extra_paths}) if (cblas_lib) set(blas_libraries ${cblas_lib}) set(blas_found 1) message(STATUS "Found CBLAS library") endif() mark_as_advanced( cblas_lib) endif() if (NOT blas_found) 
find_library(generic_blas blas PATHS ${extra_paths}) if (generic_blas) set(blas_libraries ${generic_blas}) set(blas_found 1) message(STATUS "Found BLAS library") endif() mark_as_advanced( generic_blas) endif() # Make sure we really found a CBLAS library. That is, it needs to expose # the proper cblas link symbols. So here we test if one of them is present # and assume everything is good if it is. Note that we don't do this check if # we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work # with it. But it's fine since the MKL should always have cblas. if (blas_found AND NOT found_intel_mkl) set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) CHECK_FUNCTION_EXISTS(cblas_ddot FOUND_BLAS_HAS_CBLAS) if (NOT FOUND_BLAS_HAS_CBLAS) message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") set(blas_found 0) set(lapack_found 0) endif() endif() elseif(WIN32 AND NOT MINGW) message(STATUS "Searching for BLAS and LAPACK") include(CheckTypeSize) check_type_size( "void*" SIZE_OF_VOID_PTR) if (SIZE_OF_VOID_PTR EQUAL 8) set( mkl_search_path "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/intel64" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/tbb/lib/intel64/vc14" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/intel64" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/intel64" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/intel64/vc14" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/intel64/vc_mt" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/intel64" "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64" "C:/Program Files (x86)/Intel/Composer XE/tbb/lib/intel64/vc14" "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64" "C:/Program Files/Intel/Composer XE/mkl/lib/intel64" "C:/Program 
Files/Intel/Composer XE/tbb/lib/intel64/vc14" "C:/Program Files/Intel/Composer XE/compiler/lib/intel64" "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib" "C:/Program Files (x86)/Intel/oneAPI/compiler/*/lib" "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib/intel64" "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/compiler/lib/intel64_win" ) set (mkl_redist_path "C:/Program Files (x86)/Intel/oneAPI/compiler/*/bin" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/redist/intel64/compiler" "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/redist/intel64_win/compiler" ) find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) else() set( mkl_search_path "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/ia32" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/tbb/lib/ia32/vc14" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/ia32" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/ia32" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/ia32/vc14" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/ia32/vc_mt" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/ia32" "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32" "C:/Program Files (x86)/Intel/Composer XE/tbb/lib/ia32/vc14" "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32" "C:/Program Files/Intel/Composer XE/mkl/lib/ia32" "C:/Program Files/Intel/Composer XE/tbb/lib/ia32/vc14" "C:/Program Files/Intel/Composer XE/compiler/lib/ia32" "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib/ia32" "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/compiler/lib/ia32_win" ) set (mkl_redist_path "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/redist/ia32/compiler" "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/redist/ia32_win/compiler" ) find_library(mkl_intel 
mkl_intel_c ${mkl_search_path}) endif() # Get mkl_include_dir set(mkl_include_search_path "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/include" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/include" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/include" "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/include" "C:/Program Files (x86)/Intel/Composer XE/mkl/include" "C:/Program Files (x86)/Intel/Composer XE/compiler/include" "C:/Program Files/Intel/Composer XE/mkl/include" "C:/Program Files/Intel/Composer XE/compiler/include" "C:/Program Files (x86)/Intel/oneAPI/mkl/*/include" ) find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) mark_as_advanced(mkl_include_dir) # Search for the needed libraries from the MKL. find_library(mkl_core mkl_core ${mkl_search_path}) set(mkl_libs ${mkl_intel} ${mkl_core}) mark_as_advanced(mkl_libs mkl_intel mkl_core) if (DLIB_USE_MKL_WITH_TBB) find_library(mkl_tbb_thread mkl_tbb_thread ${mkl_search_path}) find_library(mkl_tbb tbb ${mkl_search_path}) mark_as_advanced(mkl_tbb_thread mkl_tbb) list(APPEND mkl_libs ${mkl_tbb_thread} ${mkl_tbb}) elseif (DLIB_USE_MKL_SEQUENTIAL) find_library(mkl_sequential mkl_sequential ${mkl_search_path}) mark_as_advanced(mkl_sequential) list(APPEND mkl_libs ${mkl_sequential}) else() find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) mark_as_advanced(mkl_thread) if (mkl_thread) find_library(mkl_iomp libiomp5md ${mkl_search_path}) mark_as_advanced(mkl_iomp) list(APPEND mkl_libs ${mkl_thread} ${mkl_iomp}) # See if we can find the dll that goes with this, so we can copy it to # the output folder, since a very large number of windows users don't # understand that they need to add the Intel MKL's folders to their # PATH to use the Intel MKL. They then complain on the dlib forums. 
# Copying the Intel MKL dlls to the output directory removes the need # to add the Intel MKL to the PATH. find_file(mkl_iomp_dll "libiomp5md.dll" ${mkl_redist_path}) if (mkl_iomp_dll) message(STATUS "FOUND libiomp5md.dll: ${mkl_iomp_dll}") endif() endif() endif() # If we found the MKL if (mkl_intel AND mkl_core AND ((mkl_tbb_thread AND mkl_tbb) OR mkl_sequential OR (mkl_thread AND mkl_iomp))) set(blas_libraries ${mkl_libs}) set(lapack_libraries ${mkl_libs}) set(blas_found 1) set(lapack_found 1) set(found_intel_mkl 1) message(STATUS "Found Intel MKL BLAS/LAPACK library") # Make sure the version of the Intel MKL we found is compatible with # the compiler we are using. One way to do this check is to see if we can # link to it right now. set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) CHECK_FUNCTION_EXISTS(cblas_ddot MKL_HAS_CBLAS) if (NOT MKL_HAS_CBLAS) message("BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") set(blas_found 0) set(lapack_found 0) endif() endif() if (found_intel_mkl AND mkl_include_dir) set(found_intel_mkl_headers 1) endif() endif() # When all else fails use CMake's built in functions to find BLAS and LAPACK if (NOT blas_found) find_package(BLAS QUIET) if (${BLAS_FOUND}) set(blas_libraries ${BLAS_LIBRARIES}) set(blas_found 1) if (NOT lapack_found) find_package(LAPACK QUIET) if (${LAPACK_FOUND}) set(lapack_libraries ${LAPACK_LIBRARIES}) set(lapack_found 1) endif() endif() endif() endif() # If using lapack, determine whether to mangle functions if (lapack_found) include(CheckFortranFunctionExists) set(CMAKE_REQUIRED_LIBRARIES ${lapack_libraries}) check_function_exists("sgesv" LAPACK_FOUND_C_UNMANGLED) check_function_exists("sgesv_" LAPACK_FOUND_C_MANGLED) if (CMAKE_Fortran_COMPILER_LOADED) check_fortran_function_exists("sgesv" LAPACK_FOUND_FORTRAN_UNMANGLED) check_fortran_function_exists("sgesv_" LAPACK_FOUND_FORTRAN_MANGLED) endif () if (LAPACK_FOUND_C_MANGLED OR LAPACK_FOUND_FORTRAN_MANGLED) set(lapack_with_underscore 1) 
elseif (LAPACK_FOUND_C_UNMANGLED OR LAPACK_FOUND_FORTRAN_UNMANGLED) set(lapack_without_underscore 1) endif () endif() if (UNIX OR MINGW) if (NOT blas_found) message(" *****************************************************************************") message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***") message(" *** install an optimized BLAS such as OpenBLAS or the Intel MKL your code ***") message(" *** will run faster. On Ubuntu you can install OpenBLAS by executing: ***") message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***") message(" *** Or you can easily install OpenBLAS from source by downloading the ***") message(" *** source tar file from http://www.openblas.net, extracting it, and ***") message(" *** running: ***") message(" *** make; sudo make install ***") message(" *****************************************************************************") endif() endif() ================================================ FILE: dlib/cmake_utils/find_ffmpeg.cmake ================================================ cmake_minimum_required(VERSION 3.10.0) message(STATUS "Searching for FFMPEG/LIBAV") find_package(PkgConfig) if (PkgConfig_FOUND) pkg_check_modules(FFMPEG IMPORTED_TARGET libavdevice libavfilter libavformat libavcodec libswresample libswscale libavutil ) if (FFMPEG_FOUND) message(STATUS "Found FFMPEG/LIBAV via pkg-config in `${FFMPEG_LIBRARY_DIRS}`") else() message(" *****************************************************************************") message(" *** No FFMPEG/LIBAV libraries found. 
***") message(" *** On Ubuntu you can install them by executing ***") message(" *** sudo apt install libavdevice-dev libavfilter-dev libavformat-dev ***") message(" *** sudo apt install libavcodec-dev libswresample-dev libswscale-dev ***") message(" *** sudo apt install libavutil-dev ***") message(" *****************************************************************************") endif() else() message(STATUS "PkgConfig could not be found, FFMPEG won't be available") set(FFMPEG_FOUND 0) endif() ================================================ FILE: dlib/cmake_utils/find_libjpeg.cmake ================================================ #This script just runs CMake's built in JPEG finding tool. But it also checks that the #copy of libjpeg that cmake finds actually builds and links. cmake_minimum_required(VERSION 3.10.0) if (BUILDING_PYTHON_IN_MSVC) # Never use any system copy of libjpeg when building python in visual studio set(JPEG_FOUND 0) return() endif() # Don't rerun this script if its already been executed. if (DEFINED JPEG_FOUND) return() endif() find_package(JPEG QUIET) if(JPEG_FOUND) set(JPEG_TEST_CMAKE_FLAGS "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") try_compile(test_for_libjpeg_worked ${PROJECT_BINARY_DIR}/test_for_libjpeg_build ${CMAKE_CURRENT_LIST_DIR}/test_for_libjpeg test_if_libjpeg_is_broken CMAKE_FLAGS "${JPEG_TEST_CMAKE_FLAGS}") message (STATUS "Found system copy of libjpeg: ${JPEG_LIBRARY}") if(NOT test_for_libjpeg_worked) set(JPEG_FOUND 0) message (STATUS "System copy of libjpeg is broken or too old. 
Will build our own libjpeg and use that instead.") endif() endif() ================================================ FILE: dlib/cmake_utils/find_libjxl.cmake ================================================ #============================================================================= # Find JPEG XL library #============================================================================= # Find the native JPEG XL headers and libraries. # # JXL_INCLUDE_DIRS - where to find jxl/decode_cxx.h, etc. # JXL_LIBRARIES - List of libraries when using jxl. # JXL_FOUND - True if jxl is found. #============================================================================= # Look for the header file. message(STATUS "Searching for JPEG XL") find_package(PkgConfig) if (PkgConfig_FOUND) pkg_check_modules(JXL IMPORTED_TARGET libjxl libjxl_cms libjxl_threads) if (JXL_FOUND) message(STATUS "Found libjxl via pkg-config in `${JXL_LIBRARY_DIRS}`") else() message(" *****************************************************************************") message(" *** No JPEG XL libraries found. 
***") message(" *** On Ubuntu 23.04 and newer you can install them by executing ***") message(" *** sudo apt install libjxl-dev ***") message(" *** ***") message(" *** Otherwise, you can find precompiled packages here: ***") message(" *** https://github.com/libjxl/libjxl/releases ***") message(" *****************************************************************************") endif() else() message(STATUS "PkgConfig could not be found, JPEG XL support won't be available") set(JXL_FOUND 0) endif() if(JXL_FOUND) set(JXL_TEST_CMAKE_FLAGS "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") try_compile(test_for_libjxl_worked ${PROJECT_BINARY_DIR}/test_for_libjxl_build ${CMAKE_CURRENT_LIST_DIR}/test_for_libjxl test_if_libjxl_is_broken CMAKE_FLAGS "${JXL_TEST_CMAKE_FLAGS}") if(NOT test_for_libjxl_worked) set(JXL_FOUND 0) message (STATUS "System copy of libjxl is either too old or broken. Will disable JPEG XL support.") endif() endif() ================================================ FILE: dlib/cmake_utils/find_libpng.cmake ================================================ #This script just runs CMake's built in PNG finding tool. But it also checks that the #copy of libpng that cmake finds actually builds and links. cmake_minimum_required(VERSION 3.10.0) if (BUILDING_PYTHON_IN_MSVC) # Never use any system copy of libpng when building python in visual studio set(PNG_FOUND 0) return() endif() # Don't rerun this script if its already been executed. 
if (DEFINED PNG_FOUND) return() endif() find_package(PNG QUIET) if(PNG_FOUND) set(PNG_TEST_CMAKE_FLAGS "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") try_compile(test_for_libpng_worked ${PROJECT_BINARY_DIR}/test_for_libpng_build ${CMAKE_CURRENT_LIST_DIR}/test_for_libpng test_if_libpng_is_broken CMAKE_FLAGS "${PNG_TEST_CMAKE_FLAGS}") message (STATUS "Found system copy of libpng: ${PNG_LIBRARIES}") if(NOT test_for_libpng_worked) set(PNG_FOUND 0) message (STATUS "System copy of libpng is broken. Will build our own libpng and use that instead.") endif() endif() ================================================ FILE: dlib/cmake_utils/find_libwebp.cmake ================================================ #============================================================================= # Find WebP library # From OpenCV #============================================================================= # Find the native WebP headers and libraries. # # WEBP_INCLUDE_DIRS - where to find webp/decode.h, etc. # WEBP_LIBRARIES - List of libraries when using webp. # WEBP_FOUND - True if webp is found. #============================================================================= # Look for the header file. unset(WEBP_FOUND) find_path(WEBP_INCLUDE_DIR NAMES webp/decode.h) if(NOT WEBP_INCLUDE_DIR) unset(WEBP_FOUND) else() mark_as_advanced(WEBP_INCLUDE_DIR) # Look for the library. 
find_library(WEBP_LIBRARY NAMES webp) mark_as_advanced(WEBP_LIBRARY) # handle the QUIETLY and REQUIRED arguments and set WEBP_FOUND to TRUE if # all listed variables are TRUE include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) find_package_handle_standard_args(WebP DEFAULT_MSG WEBP_LIBRARY WEBP_INCLUDE_DIR) set(WEBP_LIBRARIES ${WEBP_LIBRARY}) set(WEBP_INCLUDE_DIRS ${WEBP_INCLUDE_DIR}) endif() if(WEBP_FOUND) set(WEBP_TEST_CMAKE_FLAGS "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") try_compile(test_for_libwebp_worked ${PROJECT_BINARY_DIR}/test_for_libwebp_build ${CMAKE_CURRENT_LIST_DIR}/test_for_libwebp test_if_libwebp_is_broken CMAKE_FLAGS "${WEBP_TEST_CMAKE_FLAGS}") if(NOT test_for_libwebp_worked) set(WEBP_FOUND 0) message (STATUS "System copy of libwebp is either too old or broken. Will disable WebP support.") endif() endif() ================================================ FILE: dlib/cmake_utils/release_build_by_default ================================================ #set default build type to Release if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." FORCE) endif() ================================================ FILE: dlib/cmake_utils/set_compiler_specific_options.cmake ================================================ cmake_minimum_required(VERSION 3.10.0) # Check if we are being built as part of a pybind11 module. if (COMMAND pybind11_add_module) # For python users, enable SSE4 and AVX if they have these instructions. 
include(${CMAKE_CURRENT_LIST_DIR}/check_if_sse4_instructions_executable_on_host.cmake) if (SSE4_IS_AVAILABLE_ON_HOST) set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Compile your program with SSE4 instructions") endif() include(${CMAKE_CURRENT_LIST_DIR}/check_if_avx_instructions_executable_on_host.cmake) if (AVX_IS_AVAILABLE_ON_HOST) set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Compile your program with AVX instructions") endif() include(${CMAKE_CURRENT_LIST_DIR}/check_if_neon_available.cmake) if (ARM_NEON_IS_AVAILABLE) set(USE_NEON_INSTRUCTIONS ON CACHE BOOL "Compile your program with ARM-NEON instructions") endif() endif() set(gcc_like_compilers GNU Clang Intel) set(intel_archs x86_64 i386 i686 AMD64 amd64 x86) # Setup some options to allow a user to enable SSE and AVX instruction use. if ((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND (";${intel_archs};" MATCHES ";${CMAKE_SYSTEM_PROCESSOR};") AND NOT USE_AUTO_VECTOR) option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF) option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) if(USE_AVX_INSTRUCTIONS) list(APPEND active_compile_opts -mavx) message(STATUS "Enabling AVX instructions") elseif (USE_SSE4_INSTRUCTIONS) list(APPEND active_compile_opts -msse4) message(STATUS "Enabling SSE4 instructions") elseif(USE_SSE2_INSTRUCTIONS) list(APPEND active_compile_opts -msse2) message(STATUS "Enabling SSE2 instructions") endif() elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio # Use SSE2 by default when using Visual Studio. 
option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON) option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) include(CheckTypeSize) check_type_size( "void*" SIZE_OF_VOID_PTR) if(USE_AVX_INSTRUCTIONS) list(APPEND active_compile_opts /arch:AVX) message(STATUS "Enabling AVX instructions") elseif (USE_SSE4_INSTRUCTIONS) # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. # So only give it when we are doing a 32 bit build. if (SIZE_OF_VOID_PTR EQUAL 4) list(APPEND active_compile_opts /arch:SSE2) endif() message(STATUS "Enabling SSE4 instructions") list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE3") list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE41") elseif(USE_SSE2_INSTRUCTIONS) # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. # So only give it when we are doing a 32 bit build. if (SIZE_OF_VOID_PTR EQUAL 4) list(APPEND active_compile_opts /arch:SSE2) endif() message(STATUS "Enabling SSE2 instructions") list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") endif() elseif((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm")) option(USE_NEON_INSTRUCTIONS "Compile your program with ARM-NEON instructions" OFF) if(USE_NEON_INSTRUCTIONS) list(APPEND active_compile_opts -mfpu=neon) message(STATUS "Enabling ARM-NEON instructions") endif() endif() if (CMAKE_COMPILER_IS_GNUCXX) # By default, g++ won't warn or error if you forget to return a value in a # function which requires you to do so. This option makes it give a warning # for doing this. list(APPEND active_compile_opts "-Wreturn-type") endif() if ("Clang" MATCHES ${CMAKE_CXX_COMPILER_ID} AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0) # Clang 6 had a default template recursion depth of 256. 
This was changed to 1024 in Clang 7. # It must be increased on Clang 6 and below to ensure that the dnn examples don't error out. list(APPEND active_compile_opts "-ftemplate-depth=500") endif() if (MSVC) # By default Visual Studio does not support .obj files with more than 65k sections. # However, code generated by file_to_code_ex and code using DNN module can have # them. So this flag enables > 65k sections, but produces .obj files # that will not be readable by VS 2005. list(APPEND active_compile_opts "/bigobj") # Build dlib with all cores. Don't propagate the setting to client programs # though since they might compile large translation units that use too much # RAM. list(APPEND active_compile_opts_private "/MP") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 3.3) # Clang can compile all Dlib's code at Windows platform. Tested with Clang 5 list(APPEND active_compile_opts -Xclang) list(APPEND active_compile_opts -fcxx-exceptions) endif() endif() ================================================ FILE: dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake ================================================ # Including this cmake script into your cmake project will cause visual studio # to build your project against the static C runtime. 
cmake_minimum_required(VERSION 3.10.0) if (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") option (DLIB_FORCE_MSVC_STATIC_RUNTIME "use static runtime" ON) if (DLIB_FORCE_MSVC_STATIC_RUNTIME) foreach(flag_var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) if(${flag_var} MATCHES "/MD") string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") endif() endforeach(flag_var) endif () endif() ================================================ FILE: dlib/cmake_utils/test_for_avx/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(avx_test) set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Use AVX instructions") # Pull this in since it sets the AVX compile options by putting that kind of stuff into the active_compile_opts list. include(../set_compiler_specific_options.cmake) try_run(run_result compile_result ${PROJECT_BINARY_DIR}/avx_test_try_run_build ${CMAKE_CURRENT_LIST_DIR}/avx_test.cpp COMPILE_DEFINITIONS ${active_compile_opts}) message(STATUS "run_result = ${run_result}") message(STATUS "compile_result = ${compile_result}") if ("${run_result}" EQUAL 0 AND compile_result) message(STATUS "Ran AVX test program successfully, you have AVX available.") else() message(STATUS "Unable to run AVX test program, you don't seem to have AVX instructions available.") # make this build fail so that calling try_compile statements will error in this case. 
add_library(make_this_build_fail ${CMAKE_CURRENT_LIST_DIR}/this_file_doesnt_compile.cpp) endif() ================================================ FILE: dlib/cmake_utils/test_for_avx/avx_test.cpp ================================================ #include int main() { __m256 x; x = _mm256_set1_ps(1.23); x = _mm256_add_ps(x,x); return 0; } // ------------------------------------------------------------------------------------ ================================================ FILE: dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp ================================================ #error "This file doesn't compile!" ================================================ FILE: dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(test_if_libjpeg_is_broken) find_package(JPEG) include_directories(${JPEG_INCLUDE_DIR}) add_executable(libjpeg_test libjpeg_test.cpp) target_link_libraries(libjpeg_test ${JPEG_LIBRARY}) ================================================ FILE: dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp ================================================ // Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include struct jpeg_loader_error_mgr { jpeg_error_mgr pub; jmp_buf setjmp_buffer; }; void jpeg_loader_error_exit (j_common_ptr cinfo) { jpeg_loader_error_mgr* myerr = (jpeg_loader_error_mgr*) cinfo->err; longjmp(myerr->setjmp_buffer, 1); } // This code doesn't really make a lot of sense. It's just calling all the libjpeg functions to make // sure they can be compiled and linked. int main() { std::cerr << "This program is just for build system testing. Don't actually run it." 
<< std::endl; abort(); FILE *fp = fopen("whatever.jpg", "rb" ); jpeg_decompress_struct cinfo; jpeg_loader_error_mgr jerr; cinfo.err = jpeg_std_error(&jerr.pub); jerr.pub.error_exit = jpeg_loader_error_exit; setjmp(jerr.setjmp_buffer); jpeg_create_decompress(&cinfo); jpeg_stdio_src(&cinfo, fp); if (false) { unsigned char imgbuffer[1234]; jpeg_mem_src(&cinfo, imgbuffer, sizeof(imgbuffer)); } jpeg_read_header(&cinfo, TRUE); jpeg_start_decompress(&cinfo); unsigned long height_ = cinfo.output_height; unsigned long width_ = cinfo.output_width; unsigned long output_components_ = cinfo.output_components; unsigned char* rows[123]; while (cinfo.output_scanline < cinfo.output_height) { jpeg_read_scanlines(&cinfo, &rows[cinfo.output_scanline], 100); } jpeg_finish_decompress(&cinfo); jpeg_destroy_decompress(&cinfo); fclose( fp ); } ================================================ FILE: dlib/cmake_utils/test_for_libjxl/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(test_if_libjxl_is_broken) include_directories(${JXL_INCLUDE_DIR}) add_executable(libjxl_test libjxl_test.cpp) target_link_libraries(libjxl_test ${JXL_LIBRARY}) ================================================ FILE: dlib/cmake_utils/test_for_libjxl/libjxl_test.cpp ================================================ // Copyright (C) 2023 Davis E. King (davis@dlib.net), Adrià Arrufat // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include // This code doesn't really make a lot of sense. It's just calling a few libjxl functions to make // sure they can be compiled and linked. int main() { std::cerr << "This program is just for build system testing. Don't actually run it."
<< std::endl; std::abort(); auto enc = JxlEncoderMake(nullptr); auto dec = JxlDecoderMake(nullptr); auto runner = JxlResizableParallelRunnerMake(nullptr); } ================================================ FILE: dlib/cmake_utils/test_for_libpng/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(test_if_libpng_is_broken) find_package(PNG) include_directories(${PNG_INCLUDE_DIR}) add_executable(libpng_test libpng_test.cpp) target_link_libraries(libpng_test ${PNG_LIBRARIES}) ================================================ FILE: dlib/cmake_utils/test_for_libpng/libpng_test.cpp ================================================ // Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include void png_loader_user_error_fn_silent(png_structp png_struct, png_const_charp ) { longjmp(png_jmpbuf(png_struct),1); } void png_loader_user_warning_fn_silent(png_structp , png_const_charp ) { } // This code doesn't really make a lot of sense. It's just calling all the libpng functions to make // sure they can be compiled and linked. int main() { std::cerr << "This program is just for build system testing. Don't actually run it." 
<< std::endl; abort(); png_bytep* row_pointers_; png_structp png_ptr_; png_infop info_ptr_; png_infop end_info_; FILE *fp = fopen( "whatever.png", "rb" ); png_byte sig[8]; fread( sig, 1, 8, fp ); png_sig_cmp( sig, 0, 8 ); png_ptr_ = png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, &png_loader_user_error_fn_silent, &png_loader_user_warning_fn_silent ); png_get_header_ver(NULL); info_ptr_ = png_create_info_struct( png_ptr_ ); end_info_ = png_create_info_struct( png_ptr_ ); setjmp(png_jmpbuf(png_ptr_)); png_set_palette_to_rgb(png_ptr_); png_init_io( png_ptr_, fp ); png_set_sig_bytes( png_ptr_, 8 ); // flags force one byte per channel output int png_transforms = PNG_TRANSFORM_PACKING; png_read_png( png_ptr_, info_ptr_, png_transforms, NULL ); png_get_image_height( png_ptr_, info_ptr_ ); png_get_image_width( png_ptr_, info_ptr_ ); png_get_bit_depth( png_ptr_, info_ptr_ ); png_get_color_type( png_ptr_, info_ptr_ ); png_get_rows( png_ptr_, info_ptr_ ); fclose(fp); png_destroy_read_struct(&png_ptr_, &info_ptr_, &end_info_); } ================================================ FILE: dlib/cmake_utils/test_for_libwebp/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(test_if_libwebp_is_broken) include_directories(${WEBP_INCLUDE_DIR}) add_executable(libwebp_test libwebp_test.cpp) target_link_libraries(libwebp_test ${WEBP_LIBRARY}) ================================================ FILE: dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp ================================================ // Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt // License: Boost Software License See LICENSE.txt for the full license. #include #include #include // This code doesn't really make a lot of sense. It's just calling a few libwebp functions to make // sure they can be compiled and linked. int main() { std::cerr << "This program is just for build system testing. Don't actually run it."
<< std::endl; std::abort(); uint8_t* data; size_t output_size = 0; int width, height, stride; float quality; output_size = WebPEncodeRGB(data, width, height, stride, quality, &data); WebPDecodeRGBInto(data, output_size, data, output_size, stride); WebPFree(data); } ================================================ FILE: dlib/cmake_utils/test_for_neon/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(neon_test) add_library(neon_test STATIC neon_test.cpp ) ================================================ FILE: dlib/cmake_utils/test_for_neon/neon_test.cpp ================================================ #ifdef __ARM_NEON__ #else #error "No NEON" #endif int main(){} // ------------------------------------------------------------------------------------ ================================================ FILE: dlib/cmake_utils/test_for_sse4/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.10.0) project(sse4_test) set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Use SSE4 instructions") # Pull this in since it sets the SSE4 compile options by putting that kind of stuff into the active_compile_opts list. include(../set_compiler_specific_options.cmake) try_run(run_result compile_result ${PROJECT_BINARY_DIR}/sse4_test_try_run_build ${CMAKE_CURRENT_LIST_DIR}/sse4_test.cpp COMPILE_DEFINITIONS ${active_compile_opts}) message(STATUS "run_result = ${run_result}") message(STATUS "compile_result = ${compile_result}") if ("${run_result}" EQUAL 0 AND compile_result) message(STATUS "Ran SSE4 test program successfully, you have SSE4 available.") else() message(STATUS "Unable to run SSE4 test program, you don't seem to have SSE4 instructions available.") # make this build fail so that calling try_compile statements will error in this case. 
add_library(make_this_build_fail ${CMAKE_CURRENT_LIST_DIR}/this_file_doesnt_compile.cpp)
endif()
================================================
FILE: dlib/cmake_utils/test_for_sse4/sse4_test.cpp
================================================
// Uses an SSE4.1 intrinsic (_mm_ceil_ps), so the program only builds when SSE4 is
// enabled in the compiler and only runs successfully on a CPU that supports it.
// NOTE(review): the header names after each #include are missing from this extracted copy.
#include
#include
#include
#include // SSE3
#include
#include // SSE4

int main()
{
    __m128 x;
    x = _mm_set1_ps(1.23);
    x = _mm_ceil_ps(x);
    return 0;
}

// ------------------------------------------------------------------------------------
================================================
FILE: dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp
================================================
// Deliberately fails to compile; it is the source of the make_this_build_fail target.
#error "This file doesn't compile!"
================================================
FILE: dlib/cmd_line_parser/cmd_line_parser_check_1.h
================================================
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CMD_LINE_PARSER_CHECk_1_
#define DLIB_CMD_LINE_PARSER_CHECk_1_

// NOTE(review): in this extracted copy, text inside angle brackets appears to have been
// stripped (bare #include lines, "template template", "std::vector", "const_cast(info)",
// "string_cast(...)").  The remaining tokens are reproduced unchanged.
#include "cmd_line_parser_kernel_abstract.h"
#include
#include
#include "../string.h"
#include

namespace dlib
{

    template <
        typename clp_base
        >
    class cmd_line_parser_check_1 : public clp_base
    {

        /*!
            This extension doesn't add any state.

            It adds validation helpers on top of clp_base.  Each check_* member
            inspects the already-parsed options through this->option(...) and
            throws cmd_line_check_error describing the first violation it finds.
        !*/

    public:

        typedef typename clp_base::char_type char_type;
        typedef typename clp_base::string_type string_type;

    // ------------------------------------------------------------------------------------

        // Exception thrown by the check_* members.  The stored fields identify the
        // offending option(s); set_info_string() turns them into the info message.
        class cmd_line_check_error : public dlib::error
        {
            friend class cmd_line_parser_check_1;

            // Used with EINVALID_OPTION_ARG: opt_ is the option, arg_ the rejected
            // argument (or the string_cast_error text).
            cmd_line_check_error(
                error_type t,
                const string_type& opt_,
                const string_type& arg_
            ) :
                dlib::error(t),
                opt(opt_),
                opt2(),
                arg(arg_),
                required_opts()
            { set_info_string(); }

            // Used with EINCOMPATIBLE_OPTIONS: opt_ and opt2_ are the conflicting options.
            cmd_line_check_error(
                error_type t,
                const string_type& opt_,
                const string_type& opt2_,
                int // this is just to make this constructor different from the one above
            ) :
                dlib::error(t),
                opt(opt_),
                opt2(opt2_),
                arg(),
                required_opts()
            { set_info_string(); }

            // Used with EMISSING_REQUIRED_OPTION: vect lists the option(s) opt_ depends on.
            cmd_line_check_error (
                error_type t,
                const string_type& opt_,
                const std::vector& vect
            ) :
                dlib::error(t),
                opt(opt_),
                opt2(),
                arg(),
                required_opts(vect)
            { set_info_string(); }

            // Used with EMULTIPLE_OCCURANCES: opt_ is the option given more than once.
            cmd_line_check_error(
                error_type t,
                const string_type& opt_
            ) :
                dlib::error(t),
                opt(opt_),
                opt2(),
                arg(),
                required_opts()
            { set_info_string(); }

            ~cmd_line_check_error() noexcept {}

            // Formats the message for the current error type, passes it through
            // wrap_string() and stores the result in info.
            void set_info_string (
            )
            {
                std::ostringstream sout;
                switch (type)
                {
                    case EINVALID_OPTION_ARG:
                        sout << "Command line error: '" << narrow(arg) << "' is not a valid argument to "
                             << "the '" << narrow(opt) << "' option.";
                        break;
                    case EMISSING_REQUIRED_OPTION:
                        // NOTE(review): opt/opt2/arg are passed through narrow(), but the
                        // required_opts entries are streamed directly; for a wide char_type
                        // that may not compile -- confirm.
                        if (required_opts.size() == 1)
                        {
                            sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of "
                                 << "the '" << required_opts[0] << "' option.";
                        }
                        else
                        {
                            sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of "
                                 << "one of the following options: ";
                            // emits a list of the form: 'a', 'b', 'c' or 'd'.
                            for (unsigned long i = 0; i < required_opts.size(); ++i)
                            {
                                if (i == required_opts.size()-2)
                                    sout << "'" << required_opts[i] << "' or ";
                                else if (i == required_opts.size()-1)
                                    sout << "'" << required_opts[i] << "'.";
                                else
                                    sout << "'" << required_opts[i] << "', ";
                            }
                        }
                        break;
                    case EINCOMPATIBLE_OPTIONS:
                        sout << "Command line error: The '" << narrow(opt) << "' and '" << narrow(opt2)
                             << "' options cannot be given together on the command line.";
                        break;
                    case EMULTIPLE_OCCURANCES:
                        sout << "Command line error: The '" << narrow(opt) << "' option can only "
                             << "be given on the command line once.";
                        break;
                    default:
                        sout << "Command line error.";
                        break;
                }
                const_cast(info) = wrap_string(sout.str(),0,0);
            }

        public:
            const string_type opt;
            const string_type opt2;
            const string_type arg;
            const std::vector required_opts;
        };

    // ------------------------------------------------------------------------------------

        template < typename T >
        void check_option_arg_type (
            const string_type& option_name
        ) const;

        template < typename T >
        void check_option_arg_range (
            const string_type& option_name,
            const T& first,
            const T& last
        ) const;

        template < typename T, size_t length >
        void check_option_arg_range (
            const string_type& option_name,
            const T (&arg_set)[length]
        ) const;

        template < size_t length >
        void check_option_arg_range (
            const string_type& option_name,
            const char_type* (&arg_set)[length]
        ) const;

        template < size_t length >
        void check_incompatible_options (
            const char_type* (&option_set)[length]
        ) const;

        template < size_t length >
        void check_one_time_options (
            const char_type* (&option_set)[length]
        ) const;

        void check_incompatible_options (
            const string_type& option_name1,
            const string_type& option_name2
        ) const;

        void check_sub_option (
            const string_type& parent_option,
            const string_type& sub_option
        ) const;

        template < size_t length >
        void check_sub_options (
            const string_type& parent_option,
            const char_type* (&sub_option_set)[length]
        ) const;

        template < size_t length >
        void check_sub_options (
            const char_type* (&parent_option_set)[length],
            const string_type& sub_option
        ) const;

        template < size_t parent_length, size_t sub_length >
        void check_sub_options (
            const char_type* (&parent_option_set)[parent_length],
            const char_type* (&sub_option_set)[sub_length]
        ) const;
    };

    // Swaps two parsers by delegating to the member swap().
    template <
        typename clp_base
        >
    inline void swap (
        cmd_line_parser_check_1& a,
        cmd_line_parser_check_1& b
    ) { a.swap(b); }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Tries to string_cast every argument of every occurrence of option_name.
    // A string_cast_error is rethrown as EINVALID_OPTION_ARG carrying e.info.
    template template
    void cmd_line_parser_check_1::
    check_option_arg_type (
        const string_type& option_name
    ) const
    {
        try
        {
            const typename clp_base::option_type& opt = this->option(option_name);
            const unsigned long number_of_arguments = opt.number_of_arguments();
            const unsigned long count = opt.count();
            for (unsigned long i = 0; i < number_of_arguments; ++i)
            {
                for (unsigned long j = 0; j < count; ++j)
                {
                    // result discarded: only whether the conversion succeeds matters
                    string_cast(opt.argument(i,j));
                }
            }
        }
        catch (string_cast_error& e)
        {
            throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info);
        }
    }

// ----------------------------------------------------------------------------------------

    // Converts each argument of option_name.  It throws EINVALID_OPTION_ARG (with the
    // argument text) for any value outside the inclusive range [first, last], and
    // EINVALID_OPTION_ARG with e.info when the conversion itself fails.
    template template
    void cmd_line_parser_check_1::
    check_option_arg_range (
        const string_type& option_name,
        const T& first,
        const T& last
    ) const
    {
        try
        {
            const typename clp_base::option_type& opt = this->option(option_name);
            const unsigned long number_of_arguments = opt.number_of_arguments();
            const unsigned long count = opt.count();
            for (unsigned long i = 0; i < number_of_arguments; ++i)
            {
                for (unsigned long j = 0; j < count; ++j)
                {
                    T temp(string_cast(opt.argument(i,j)));
                    if (temp < first || last < temp)
                    {
                        throw cmd_line_check_error(
                            EINVALID_OPTION_ARG,
                            option_name,
                            opt.argument(i,j)
                        );
                    }
                }
            }
        }
        catch (string_cast_error& e)
        {
            throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info);
        }
    }

// ----------------------------------------------------------------------------------------

    // Converts each argument and requires it to compare equal to some element of
    // arg_set; otherwise it throws EINVALID_OPTION_ARG with the argument text.
    template template < typename T, size_t length >
    void cmd_line_parser_check_1::
    check_option_arg_range (
        const string_type& option_name,
        const T (&arg_set)[length]
    ) const
    {
        try
        {
            const typename clp_base::option_type& opt = this->option(option_name);
            const unsigned long number_of_arguments = opt.number_of_arguments();
            const unsigned long count = opt.count();
            for (unsigned long i = 0; i < number_of_arguments; ++i)
            {
                for (unsigned long j = 0; j < count; ++j)
                {
                    T temp(string_cast(opt.argument(i,j)));
                    // linear search; k == length afterwards means no element matched
                    size_t k = 0;
                    for (; k < length; ++k)
                    {
                        if (arg_set[k] == temp)
                            break;
                    }
                    if (k == length)
                    {
                        throw cmd_line_check_error(
                            EINVALID_OPTION_ARG,
                            option_name,
                            opt.argument(i,j)
                        );
                    }
                }
            }
        }
        catch (string_cast_error& e)
        {
            throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info);
        }
    }

// ----------------------------------------------------------------------------------------

    // String-set overload: each argument must equal one of the strings in arg_set.
    // No conversion is performed, so there is no string_cast_error to catch.
    template template < size_t length >
    void cmd_line_parser_check_1::
    check_option_arg_range (
        const string_type& option_name,
        const char_type* (&arg_set)[length]
    ) const
    {
        const typename clp_base::option_type& opt = this->option(option_name);
        const unsigned long number_of_arguments = opt.number_of_arguments();
        const unsigned long count = opt.count();
        for (unsigned long i = 0; i < number_of_arguments; ++i)
        {
            for (unsigned long j = 0; j < count; ++j)
            {
                size_t k = 0;
                for (; k < length; ++k)
                {
                    if (arg_set[k] == opt.argument(i,j))
                        break;
                }
                if (k == length)
                {
                    throw cmd_line_check_error(
                        EINVALID_OPTION_ARG,
                        option_name,
                        opt.argument(i,j)
                    );
                }
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // Throws EINCOMPATIBLE_OPTIONS for the first pair (i < j) in option_set whose
    // options were both given on the command line.
    template template < size_t length >
    void cmd_line_parser_check_1::
    check_incompatible_options (
        const char_type* (&option_set)[length]
    ) const
    {
        for (size_t i = 0; i < length; ++i)
        {
            for (size_t j = i+1; j < length; ++j)
            {
                if (this->option(option_set[i]).count() > 0 &&
                    this->option(option_set[j]).count() > 0 )
                {
                    throw cmd_line_check_error(
                        EINCOMPATIBLE_OPTIONS,
                        option_set[i],
                        option_set[j],
                        0 // this argument has no meaning and is only here to make this
                        // call different from the other constructor
                    );
                }
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // Throws EINCOMPATIBLE_OPTIONS if both option_name1 and option_name2 were given.
    template
    void cmd_line_parser_check_1::
    check_incompatible_options (
        const string_type& option_name1,
        const string_type& option_name2
    ) const
    {
        if (this->option(option_name1).count() > 0 &&
            this->option(option_name2).count() > 0 )
        {
            throw cmd_line_check_error(
                EINCOMPATIBLE_OPTIONS,
                option_name1,
                option_name2,
                0 // this argument has no meaning and is only here to make this
                // call different from the other constructor
            );
        }
    }

// ----------------------------------------------------------------------------------------

    // Throws EMISSING_REQUIRED_OPTION if sub_option was given without parent_option.
    template
    void cmd_line_parser_check_1::
    check_sub_option (
        const string_type& parent_option,
        const string_type& sub_option
    ) const
    {
        if (this->option(parent_option).count() == 0)
        {
            if (this->option(sub_option).count() != 0)
            {
                std::vector vect;
                vect.resize(1);
                vect[0] = parent_option;
                throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect);
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // When parent_option is absent, throws EMISSING_REQUIRED_OPTION naming the first
    // element of sub_option_set that was given anyway.
    template template < size_t length >
    void cmd_line_parser_check_1::
    check_sub_options (
        const string_type& parent_option,
        const char_type* (&sub_option_set)[length]
    ) const
    {
        if (this->option(parent_option).count() == 0)
        {
            size_t i = 0;
            for (; i < length; ++i)
            {
                if (this->option(sub_option_set[i]).count() > 0)
                    break;
            }
            if (i != length)
            {
                std::vector vect;
                vect.resize(1);
                vect[0] = parent_option;
                throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect);
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // If sub_option was given, at least one option from parent_option_set must also
    // have been given.  Otherwise it throws EMISSING_REQUIRED_OPTION listing every parent.
    template template < size_t length >
    void cmd_line_parser_check_1::
    check_sub_options (
        const char_type* (&parent_option_set)[length],
        const string_type& sub_option
    ) const
    {
        // first check if the sub_option is present
        if (this->option(sub_option).count() > 0)
        {
            // now check if any of the parents are present
            bool parents_present = false;
            for (size_t i = 0; i < length; ++i)
            {
                if (this->option(parent_option_set[i]).count() > 0)
                {
                    parents_present = true;
                    break;
                }
            }

            if (!parents_present)
            {
                std::vector vect(parent_option_set, parent_option_set+length);
                throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect);
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // If none of parent_option_set was given, throws EMISSING_REQUIRED_OPTION for the
    // first element of sub_option_set that was given; the error lists every parent.
    template template < size_t parent_length, size_t sub_length >
    void cmd_line_parser_check_1::
    check_sub_options (
        const char_type* (&parent_option_set)[parent_length],
        const char_type* (&sub_option_set)[sub_length]
    ) const
    {
        // first check if any of the parent options are present
        bool parents_present = false;
        for (size_t i = 0; i < parent_length; ++i)
        {
            if (this->option(parent_option_set[i]).count() > 0)
            {
                parents_present = true;
                break;
            }
        }

        if (!parents_present)
        {
            // none of these sub options should be present
            size_t i = 0;
            for (; i < sub_length; ++i)
            {
                if (this->option(sub_option_set[i]).count() > 0)
                    break;
            }
            if (i != sub_length)
            {
                std::vector vect(parent_option_set, parent_option_set+parent_length);
                throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect);
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // Throws EMULTIPLE_OCCURANCES for the first option in option_set whose count() > 1.
    template template < size_t length >
    void cmd_line_parser_check_1::
    check_one_time_options (
        const char_type* (&option_set)[length]
    ) const
    {
        size_t i = 0;
        for (; i < length; ++i)
        {
            if (this->option(option_set[i]).count() > 1)
                break;
        }
        if (i != length)
        {
            throw cmd_line_check_error( EMULTIPLE_OCCURANCES, option_set[i] );
        }
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_CMD_LINE_PARSER_CHECk_1_

================================================
FILE: dlib/cmd_line_parser/cmd_line_parser_check_c.h
================================================
// Copyright (C) 2006 Davis E.
King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CMD_LINE_PARSER_CHECk_C_
#define DLIB_CMD_LINE_PARSER_CHECk_C_

#include "cmd_line_parser_kernel_abstract.h"
#include "../algs.h"
#include "../assert.h"
#include
#include "../interfaces/cmd_line_parser_option.h"
#include "../string.h"

namespace dlib
{

    // Contract-checking layer over clp_check.  Every member below asserts its
    // preconditions with DLIB_CASSERT (mainly parsed_line() and option_is_defined())
    // and then forwards to the same-named clp_check member.
    // NOTE(review): text inside angle brackets appears stripped in this extracted copy
    // (e.g. "template template", "is_pointer_type::value", "static_cast(i)").
    template <
        typename clp_check
        >
    class cmd_line_parser_check_c : public clp_check
    {
    public:

        typedef typename clp_check::char_type char_type;
        typedef typename clp_check::string_type string_type;

        template < typename T >
        void check_option_arg_type (
            const string_type& option_name
        ) const;

        template < typename T >
        void check_option_arg_range (
            const string_type& option_name,
            const T& first,
            const T& last
        ) const;

        template < typename T, size_t length >
        void check_option_arg_range (
            const string_type& option_name,
            const T (&arg_set)[length]
        ) const;

        template < size_t length >
        void check_option_arg_range (
            const string_type& option_name,
            const char_type* (&arg_set)[length]
        ) const;

        template < size_t length >
        void check_incompatible_options (
            const char_type* (&option_set)[length]
        ) const;

        template < size_t length >
        void check_one_time_options (
            const char_type* (&option_set)[length]
        ) const;

        void check_incompatible_options (
            const string_type& option_name1,
            const string_type& option_name2
        ) const;

        void check_sub_option (
            const string_type& parent_option,
            const string_type& sub_option
        ) const;

        template < size_t length >
        void check_sub_options (
            const string_type& parent_option,
            const char_type* (&sub_option_set)[length]
        ) const;

        template < size_t length >
        void check_sub_options (
            const char_type* (&parent_option_set)[length],
            const string_type& sub_option
        ) const;

        template < size_t parent_length, size_t sub_length >
        void check_sub_options (
            const char_type* (&parent_option_set)[parent_length],
            const char_type* (&sub_option_set)[sub_length]
        ) const;
    };

    // Swaps two parsers by delegating to the member swap().
    template <
        typename clp_check
        >
    inline void swap (
        cmd_line_parser_check_c& a,
        cmd_line_parser_check_c& b
    ) { a.swap(b); }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Requires: parsed_line() and option_is_defined(option_name).  The
    // COMPILE_TIME_ASSERT presumably rejects a pointer T (its argument is stripped here).
    template template
    void cmd_line_parser_check_c::
    check_option_arg_type (
        const string_type& option_name
    ) const
    {
        COMPILE_TIME_ASSERT(is_pointer_type::value == false);

        // make sure requires clause is not broken
        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name),
            "\tvoid cmd_line_parser_check::check_option_arg_type()"
            << "\n\tYou must have already parsed the command line and option_name must be valid."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\toption_name: " << option_name
        );

        clp_check::template check_option_arg_type(option_name);
    }

// ----------------------------------------------------------------------------------------

    // Requires: parsed_line(), option_is_defined(option_name) and first <= last.
    template template
    void cmd_line_parser_check_c::
    check_option_arg_range (
        const string_type& option_name,
        const T& first,
        const T& last
    ) const
    {
        COMPILE_TIME_ASSERT(is_pointer_type::value == false);

        // make sure requires clause is not broken
        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name) &&
                      first <= last,
            "\tvoid cmd_line_parser_check::check_option_arg_range()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\toption_name: " << option_name
            << "\n\tfirst: " << first
            << "\n\tlast: " << last
        );

        clp_check::check_option_arg_range(option_name,first,last);
    }

// ----------------------------------------------------------------------------------------

    // Requires: parsed_line() and option_is_defined(option_name).
    template template < typename T, size_t length >
    void cmd_line_parser_check_c::
    check_option_arg_range (
        const string_type& option_name,
        const T (&arg_set)[length]
    ) const
    {
        COMPILE_TIME_ASSERT(is_pointer_type::value == false);

        // make sure requires clause is not broken
        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name),
            "\tvoid cmd_line_parser_check::check_option_arg_range()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\toption_name: " << option_name
        );

        clp_check::check_option_arg_range(option_name,arg_set);
    }

// ----------------------------------------------------------------------------------------

    // String-set overload.  Requires: parsed_line() and option_is_defined(option_name).
    template template < size_t length >
    void cmd_line_parser_check_c::
    check_option_arg_range (
        const string_type& option_name,
        const char_type* (&arg_set)[length]
    ) const
    {
        // make sure requires clause is not broken
        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name),
            "\tvoid cmd_line_parser_check::check_option_arg_range()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\toption_name: " << option_name
        );

        clp_check::check_option_arg_range(option_name,arg_set);
    }

// ----------------------------------------------------------------------------------------

    // Requires: parsed_line() and every option in option_set is defined.
    template template < size_t length >
    void cmd_line_parser_check_c::
    check_incompatible_options (
        const char_type* (&option_set)[length]
    ) const
    {
        // make sure requires clause is not broken
        for (size_t i = 0; i < length; ++i)
        {
            DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]),
                "\tvoid cmd_line_parser_check::check_incompatible_options()"
                << "\n\tSee the requires clause for this function."
                << "\n\tthis: " << this
                << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false")
                << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
                << "\n\toption_set[i]: " << option_set[i]
                << "\n\ti: " << static_cast(i)
            );
        }

        clp_check::check_incompatible_options(option_set);
    }

// ----------------------------------------------------------------------------------------

    // Requires: parsed_line() and both option_name1 and option_name2 are defined.
    template
    void cmd_line_parser_check_c::
    check_incompatible_options (
        const string_type& option_name1,
        const string_type& option_name2
    ) const
    {
        // make sure requires clause is not broken
        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name1) &&
                      this->option_is_defined(option_name2),
            "\tvoid cmd_line_parser_check::check_incompatible_options()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(option_name1): " << ((this->option_is_defined(option_name1))?"true":"false")
            << "\n\toption_is_defined(option_name2): " << ((this->option_is_defined(option_name2))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\toption_name1: " << option_name1
            << "\n\toption_name2: " << option_name2
        );

        clp_check::check_incompatible_options(option_name1,option_name2);
    }

// ----------------------------------------------------------------------------------------

    // Requires: parsed_line() and both parent_option and sub_option are defined.
    // NOTE(review): unlike its siblings, this message streams the raw bool results of
    // parsed_line()/option_is_defined() instead of mapping them to "true"/"false".
    template
    void cmd_line_parser_check_c::
    check_sub_option (
        const string_type& parent_option,
        const string_type& sub_option
    ) const
    {
        // make sure requires clause is not broken
        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option) &&
                      this->option_is_defined(sub_option),
            "\tvoid cmd_line_parser_check::check_sub_option()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\tparsed_line(): " << this->parsed_line()
            << "\n\toption_is_defined(parent_option): " << this->option_is_defined(parent_option)
            << "\n\toption_is_defined(sub_option): " << this->option_is_defined(sub_option)
            << "\n\tparent_option: " << parent_option
            << "\n\tsub_option: " << sub_option
        );

        clp_check::check_sub_option(parent_option,sub_option);
    }

// ----------------------------------------------------------------------------------------

    // Requires: every option in sub_option_set is defined, parsed_line(), and
    // option_is_defined(parent_option).
    template template < size_t length >
    void cmd_line_parser_check_c::
    check_sub_options (
        const string_type& parent_option,
        const char_type* (&sub_option_set)[length]
    ) const
    {
        // make sure requires clause is not broken
        for (size_t i = 0; i < length; ++i)
        {
            DLIB_CASSERT( this->option_is_defined(sub_option_set[i]),
                "\tvoid cmd_line_parser_check::check_sub_options()"
                << "\n\tSee the requires clause for this function."
                << "\n\tthis: " << this
                << "\n\toption_is_defined(sub_option_set[i]): "
                << ((this->option_is_defined(sub_option_set[i]))?"true":"false")
                << "\n\tsub_option_set[i]: " << sub_option_set[i]
                << "\n\ti: " << static_cast(i)
            );
        }

        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option),
            "\tvoid cmd_line_parser_check::check_sub_options()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(parent_option): " << ((this->option_is_defined(parent_option))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\tparent_option: " << parent_option
        );

        clp_check::check_sub_options(parent_option,sub_option_set);
    }

// ----------------------------------------------------------------------------------------

    // Requires: every option in parent_option_set is defined, parsed_line(), and
    // option_is_defined(sub_option).
    template template < size_t length >
    void cmd_line_parser_check_c::
    check_sub_options (
        const char_type* (&parent_option_set)[length],
        const string_type& sub_option
    ) const
    {
        // make sure requires clause is not broken
        for (size_t i = 0; i < length; ++i)
        {
            DLIB_CASSERT( this->option_is_defined(parent_option_set[i]),
                "\tvoid cmd_line_parser_check::check_sub_options()"
                << "\n\tSee the requires clause for this function."
                << "\n\tthis: " << this
                << "\n\toption_is_defined(parent_option_set[i]): "
                << ((this->option_is_defined(parent_option_set[i]))?"true":"false")
                << "\n\tparent_option_set[i]: " << parent_option_set[i]
                << "\n\ti: " << static_cast(i)
            );
        }

        DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(sub_option),
            "\tvoid cmd_line_parser_check::check_sub_options()"
            << "\n\tSee the requires clause for this function."
            << "\n\tthis: " << this
            << "\n\toption_is_defined(sub_option): " << ((this->option_is_defined(sub_option))?"true":"false")
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
            << "\n\tsub_option: " << sub_option
        );

        clp_check::check_sub_options(parent_option_set,sub_option);
    }

// ----------------------------------------------------------------------------------------

    // Requires: every option in both sets is defined and parsed_line().
    // NOTE(review): the parent-loop message names "check_parent_options()", which is not
    // a member of this class; it probably should read check_sub_options().
    template template < size_t parent_length, size_t sub_length >
    void cmd_line_parser_check_c::
    check_sub_options (
        const char_type* (&parent_option_set)[parent_length],
        const char_type* (&sub_option_set)[sub_length]
    ) const
    {
        // make sure requires clause is not broken
        for (size_t i = 0; i < sub_length; ++i)
        {
            DLIB_CASSERT( this->option_is_defined(sub_option_set[i]),
                "\tvoid cmd_line_parser_check::check_sub_options()"
                << "\n\tSee the requires clause for this function."
                << "\n\tthis: " << this
                << "\n\toption_is_defined(sub_option_set[i]): "
                << ((this->option_is_defined(sub_option_set[i]))?"true":"false")
                << "\n\tsub_option_set[i]: " << sub_option_set[i]
                << "\n\ti: " << static_cast(i)
            );
        }

        for (size_t i = 0; i < parent_length; ++i)
        {
            DLIB_CASSERT( this->option_is_defined(parent_option_set[i]),
                "\tvoid cmd_line_parser_check::check_parent_options()"
                << "\n\tSee the requires clause for this function."
                << "\n\tthis: " << this
                << "\n\toption_is_defined(parent_option_set[i]): "
                << ((this->option_is_defined(parent_option_set[i]))?"true":"false")
                << "\n\tparent_option_set[i]: " << parent_option_set[i]
                << "\n\ti: " << static_cast(i)
            );
        }

        DLIB_CASSERT( this->parsed_line() == true ,
            "\tvoid cmd_line_parser_check::check_sub_options()"
            << "\n\tYou must have parsed the command line before you call this function."
            << "\n\tthis: " << this
            << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
        );

        clp_check::check_sub_options(parent_option_set,sub_option_set);
    }

// ----------------------------------------------------------------------------------------

    // Requires: parsed_line() and every option in option_set is defined.
    template template < size_t length >
    void cmd_line_parser_check_c::
    check_one_time_options (
        const char_type* (&option_set)[length]
    ) const
    {
        // make sure requires clause is not broken
        for (size_t i = 0; i < length; ++i)
        {
            DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]),
                "\tvoid cmd_line_parser_check::check_one_time_options()"
                << "\n\tSee the requires clause for this function."
                << "\n\tthis: " << this
                << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false")
                << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false")
                << "\n\toption_set[i]: " << option_set[i]
                << "\n\ti: " << static_cast(i)
            );
        }

        clp_check::check_one_time_options(option_set);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_CMD_LINE_PARSER_CHECk_C_

================================================
FILE: dlib/cmd_line_parser/cmd_line_parser_kernel_1.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CMD_LINE_PARSER_KERNEl_1_
#define DLIB_CMD_LINE_PARSER_KERNEl_1_

#include "cmd_line_parser_kernel_abstract.h"
#include "../algs.h"
#include
#include
#include "../interfaces/enumerable.h"
#include "../interfaces/cmd_line_parser_option.h"
#include "../assert.h"
#include "../string.h"

namespace dlib
{

    template <
        typename charT,
        typename map,
        typename sequence,
        typename sequence2
        >
    class cmd_line_parser_kernel_1 : public enumerable >
    {
        /*!
REQUIREMENTS ON map is an implementation of map/map_kernel_abstract.h is instantiated to map items of type std::basic_string to void* REQUIREMENTS ON sequence is an implementation of sequence/sequence_kernel_abstract.h and is instantiated with std::basic_string REQUIREMENTS ON sequence2 is an implementation of sequence/sequence_kernel_abstract.h and is instantiated with std::basic_string* INITIAL VALUE options.size() == 0 argv.size() == 0 have_parsed_line == false CONVENTION have_parsed_line == parsed_line() argv[index] == operator[](index) argv.size() == number_of_arguments() *((option_t*)options[name]) == option(name) options.is_in_domain(name) == option_is_defined(name) !*/ public: typedef charT char_type; typedef std::basic_string string_type; typedef cmd_line_parser_option option_type; // exception class class cmd_line_parse_error : public dlib::error { void set_info_string ( ) { std::ostringstream sout; switch (type) { case EINVALID_OPTION: sout << "Command line error: '" << narrow(item) << "' is not a valid option."; break; case ETOO_FEW_ARGS: if (num > 1) { sout << "Command line error: The '" << narrow(item) << "' option requires " << num << " arguments."; } else { sout << "Command line error: The '" << narrow(item) << "' option requires " << num << " argument."; } break; case ETOO_MANY_ARGS: sout << "Command line error: The '" << narrow(item) << "' option does not take any arguments.\n"; break; default: sout << "Command line error."; break; } const_cast(info) = wrap_string(sout.str(),0,0); } public: cmd_line_parse_error( error_type t, const std::basic_string& _item ) : dlib::error(t), item(_item), num(0) { set_info_string();} cmd_line_parse_error( error_type t, const std::basic_string& _item, unsigned long _num ) : dlib::error(t), item(_item), num(_num) { set_info_string();} cmd_line_parse_error( ) : dlib::error(), item(), num(0) { set_info_string();} ~cmd_line_parse_error() noexcept {} const std::basic_string item; const unsigned long num; }; private: 
class option_t : public cmd_line_parser_option { /*! INITIAL VALUE options.size() == 0 CONVENTION name_ == name() description_ == description() number_of_arguments_ == number_of_arguments() options[N][arg] == argument(arg,N) num_present == count() !*/ friend class cmd_line_parser_kernel_1; public: const std::basic_string& name ( ) const { return name_; } const std::basic_string& group_name ( ) const { return group_name_; } const std::basic_string& description ( ) const { return description_; } unsigned long number_of_arguments( ) const { return number_of_arguments_; } unsigned long count ( ) const { return num_present; } const std::basic_string& argument ( unsigned long arg, unsigned long N ) const { // make sure requires clause is not broken DLIB_CASSERT( N < count() && arg < number_of_arguments(), "\tconst string_type& cmd_line_parser_option::argument(unsigned long,unsigned long)" << "\n\tInvalid arguments were given to this function." << "\n\tthis: " << this << "\n\tN: " << N << "\n\targ: " << arg << "\n\tname(): " << narrow(name()) << "\n\tcount(): " << count() << "\n\tnumber_of_arguments(): " << number_of_arguments() ); return options[N][arg]; } protected: option_t ( ) : num_present(0) {} ~option_t() { clear(); } private: void clear() /*! 
ensures - #count() == 0 - clears everything out of options and frees memory !*/ { for (unsigned long i = 0; i < options.size(); ++i) { delete [] options[i]; } options.clear(); num_present = 0; } // data members std::basic_string name_; std::basic_string group_name_; std::basic_string description_; sequence2 options; unsigned long number_of_arguments_; unsigned long num_present; // restricted functions option_t(option_t&); // copy constructor option_t& operator=(option_t&); // assignment operator }; // -------------------------- public: cmd_line_parser_kernel_1 ( ); virtual ~cmd_line_parser_kernel_1 ( ); void clear( ); void parse ( int argc, const charT** argv ); void parse ( int argc, charT** argv ) { parse(argc, const_cast(argv)); } bool parsed_line( ) const; bool option_is_defined ( const string_type& name ) const; void add_option ( const string_type& name, const string_type& description, unsigned long number_of_arguments = 0 ); void set_group_name ( const string_type& group_name ); string_type get_group_name ( ) const { return group_name; } const cmd_line_parser_option& option ( const string_type& name ) const; unsigned long number_of_arguments( ) const; const string_type& operator[] ( unsigned long index ) const; void swap ( cmd_line_parser_kernel_1& item ); // functions from the enumerable interface bool at_start ( ) const { return options.at_start(); } void reset ( ) const { options.reset(); } bool current_element_valid ( ) const { return options.current_element_valid(); } const cmd_line_parser_option& element ( ) const { return *static_cast*>(options.element().value()); } cmd_line_parser_option& element ( ) { return *static_cast*>(options.element().value()); } bool move_next ( ) const { return options.move_next(); } size_t size ( ) const { return options.size(); } private: // data members map options; sequence argv; bool have_parsed_line; string_type group_name; // restricted functions cmd_line_parser_kernel_1(cmd_line_parser_kernel_1&); // copy constructor 
cmd_line_parser_kernel_1& operator=(cmd_line_parser_kernel_1&); // assignment operator }; // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > inline void swap ( cmd_line_parser_kernel_1& a, cmd_line_parser_kernel_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > cmd_line_parser_kernel_1:: cmd_line_parser_kernel_1 ( ) : have_parsed_line(false) { } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > cmd_line_parser_kernel_1:: ~cmd_line_parser_kernel_1 ( ) { // delete all option_t objects in options options.reset(); while (options.move_next()) { delete static_cast(options.element().value()); } } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > void cmd_line_parser_kernel_1:: clear( ) { have_parsed_line = false; argv.clear(); // delete all option_t objects in options options.reset(); while (options.move_next()) { delete static_cast(options.element().value()); } options.clear(); reset(); } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > void cmd_line_parser_kernel_1:: parse ( int argc_, const charT** argv ) { // make sure there aren't any 
arguments hanging around from the last time // parse was called this->argv.clear(); // make sure that the options have been cleared of any arguments since // the last time parse() was called if (have_parsed_line) { options.reset(); while (options.move_next()) { static_cast(options.element().value())->clear(); } options.reset(); } // this tells us if we have seen -- on the command line all by itself // or not. bool escape = false; const unsigned long argc = static_cast(argc_); try { for (unsigned long i = 1; i < argc; ++i) { if (argv[i][0] == _dT(charT,'-') && !escape) { // we are looking at the start of an option // -------------------------------------------------------------------- if (argv[i][1] == _dT(charT,'-')) { // we are looking at the start of a "long named" option string_type temp = &argv[i][2]; string_type first_argument; typename string_type::size_type pos = temp.find_first_of(_dT(charT,'=')); // This variable will be 1 if there is an argument supplied via the = sign // and 0 otherwise. 
unsigned long extra_argument = 0; if (pos != string_type::npos) { // there should be an extra argument extra_argument = 1; first_argument = temp.substr(pos+1); temp = temp.substr(0,pos); } // make sure this name is defined if (!options.is_in_domain(temp)) { // the long name is not a valid option if (argv[i][2] == _dT(charT,'\0')) { // there was nothing after the -- on the command line escape = true; continue; } else { // there was something after the command line but it // wasn't a valid option throw cmd_line_parse_error(EINVALID_OPTION,temp); } } option_t* o = static_cast(options[temp]); // check the number of arguments after this option and make sure // it is correct if (argc + extra_argument <= o->number_of_arguments() + i) { // there are too few arguments throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); } if (extra_argument && first_argument.size() == 0 ) { // if there would be exactly the right number of arguments if // the first_argument wasn't empty if (argc == o->number_of_arguments() + i) throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); else { // in this case we just ignore the trailing = and parse everything // the same. 
extra_argument = 0; } } // you can't force an option that doesn't have any arguments to take // one by using the --option=arg syntax if (extra_argument == 1 && o->number_of_arguments() == 0) { throw cmd_line_parse_error(ETOO_MANY_ARGS,temp); } // at this point we know that the option is ok and we should // populate its options object if (o->number_of_arguments() > 0) { string_type* stemp = new string_type[o->number_of_arguments()]; unsigned long j = 0; // add the argument after the = sign if one is present if (extra_argument) { stemp[0] = first_argument; ++j; } for (; j < o->number_of_arguments(); ++j) { stemp[j] = argv[i+j+1-extra_argument]; } o->options.add(o->options.size(),stemp); } o->num_present += 1; // adjust the value of i to account for the arguments to // this option i += o->number_of_arguments() - extra_argument; } // -------------------------------------------------------------------- else { // we are looking at the start of a list of a single char options // make sure there is something in this string other than - if (argv[i][1] == _dT(charT,'\0')) { throw cmd_line_parse_error(); } string_type temp = &argv[i][1]; const typename string_type::size_type num = temp.size(); for (unsigned long k = 0; k < num; ++k) { string_type name; // Doing this instead of name = temp[k] seems to avoid a bug in g++ (Ubuntu/Linaro 4.5.2-8ubuntu4) 4.5.2 // which results in name[0] having the wrong value. 
name.resize(1); name[0] = temp[k]; // make sure this name is defined if (!options.is_in_domain(name)) { // the name is not a valid option throw cmd_line_parse_error(EINVALID_OPTION,name); } option_t* o = static_cast(options[name]); // if there are chars immediately following this option int delta = 0; if (num != k+1) { delta = 1; } // check the number of arguments after this option and make sure // it is correct if (argc + delta <= o->number_of_arguments() + i) { // there are too few arguments std::ostringstream sout; throw cmd_line_parse_error(ETOO_FEW_ARGS,name,o->number_of_arguments()); } o->num_present += 1; // at this point we know that the option is ok and we should // populate its options object if (o->number_of_arguments() > 0) { string_type* stemp = new string_type[o->number_of_arguments()]; if (delta == 1) { temp = &argv[i][2+k]; k = (unsigned long)num; // this ensures that the argument to this // option isn't going to be treated as a // list of options stemp[0] = temp; } for (unsigned long j = 0; j < o->number_of_arguments()-delta; ++j) { stemp[j+delta] = argv[i+j+1]; } o->options.add(o->options.size(),stemp); // adjust the value of i to account for the arguments to // this option i += o->number_of_arguments()-delta; } } // for (unsigned long k = 0; k < num; ++k) } // -------------------------------------------------------------------- } else { // this is just a normal argument string_type temp = argv[i]; this->argv.add(this->argv.size(),temp); } } have_parsed_line = true; } catch (...) 
{ have_parsed_line = false; // clear all the option objects options.reset(); while (options.move_next()) { static_cast(options.element().value())->clear(); } options.reset(); throw; } } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > bool cmd_line_parser_kernel_1:: parsed_line( ) const { return have_parsed_line; } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > bool cmd_line_parser_kernel_1:: option_is_defined ( const string_type& name ) const { return options.is_in_domain(name); } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > void cmd_line_parser_kernel_1:: set_group_name ( const string_type& group_name_ ) { group_name = group_name_; } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > void cmd_line_parser_kernel_1:: add_option ( const string_type& name, const string_type& description, unsigned long number_of_arguments ) { option_t* temp = new option_t; try { temp->name_ = name; temp->group_name_ = group_name; temp->description_ = description; temp->number_of_arguments_ = number_of_arguments; void* t = temp; string_type n(name); options.add(n,t); }catch (...) 
{ delete temp; throw;} } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > const cmd_line_parser_option& cmd_line_parser_kernel_1:: option ( const string_type& name ) const { return *static_cast*>(options[name]); } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > unsigned long cmd_line_parser_kernel_1:: number_of_arguments( ) const { return argv.size(); } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > const std::basic_string& cmd_line_parser_kernel_1:: operator[] ( unsigned long index ) const { return argv[index]; } // ---------------------------------------------------------------------------------------- template < typename charT, typename map, typename sequence, typename sequence2 > void cmd_line_parser_kernel_1:: swap ( cmd_line_parser_kernel_1& item ) { options.swap(item.options); argv.swap(item.argv); exchange(have_parsed_line,item.have_parsed_line); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CMD_LINE_PARSER_KERNEl_1_ ================================================ FILE: dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ #ifdef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ #include "../algs.h" #include #include "../interfaces/enumerable.h" #include "../interfaces/cmd_line_parser_option.h" #include #include namespace dlib { template < typename charT > class cmd_line_parser : public enumerable > { /*! 
REQUIREMENTS ON charT Must be an integral type suitable for storing characters. (e.g. char or wchar_t) INITIAL VALUE - parsed_line() == false - option_is_defined(x) == false, for all values of x - get_group_name() == "" ENUMERATION ORDER The enumerator will enumerate over all the options defined in *this in alphabetical order according to the name of the option. POINTERS AND REFERENCES TO INTERNAL DATA parsed_line(), option_is_defined(), option(), number_of_arguments(), operator[](), and swap() functions do not invalidate pointers or references to internal data. All other functions have no such guarantee. WHAT THIS OBJECT REPRESENTS This object represents a command line parser. The command lines must match the following BNF. command_line ::= { | } [ -- {} ] program_name ::= arg ::= any that does not start with - option_arg ::= option_name ::= long_option_name ::= { | - } options ::= - {} {} | -- [=] { } char ::= any character other than - or = word ::= any string from argv where argv is the second parameter to main() sword ::= any suffix of a string from argv where argv is the second parameter to main() bword ::= This is an empty string which denotes the beginning of a . Options with arguments: An option with N arguments will consider the next N swords to be its arguments. so for example, if we have an option o that expects 2 arguments then the following are a few legal examples: program -o arg1 arg2 general_argument program -oarg1 arg2 general_argument arg1 and arg2 are associated with the option o and general_argument is not. Arguments not associated with an option: An argument that is not associated with an option is considered a general command line argument and is indexed by operator[] defined by the cmd_line_parser object. Additionally, if the string "--" appears in the command line all by itself then all words following it are considered to be general command line arguments. 
Consider the following two examples involving a command line and a cmd_line_parser object called parser. Example 1: command line: program general_arg1 -o arg1 arg2 general_arg2 Then the following is true (assuming the o option is defined and takes 2 arguments). parser[0] == "general_arg1" parser[1] == "general_arg2" parser.number_of_arguments() == 2 parser.option("o").argument(0) == "arg1" parser.option("o").argument(1) == "arg2" parser.option("o").count() == 1 Example 2: command line: program general_arg1 -- -o arg1 arg2 general_arg2 Then the following is true (the -- causes everything following it to be treated as a general argument). parser[0] == "general_arg1" parser[1] == "-o" parser[2] == "arg1" parser[3] == "arg2" parser[4] == "general_arg2" parser.number_of_arguments() == 5 parser.option("o").count() == 0 !*/ public: typedef charT char_type; typedef std::basic_string string_type; typedef cmd_line_parser_option option_type; // exception class class cmd_line_parse_error : public dlib::error { /*! GENERAL This exception is thrown if there is an error detected in a command line while it is being parsed. You can consult this object's type and item members to determine the nature of the error. (note that the type member is inherited from dlib::error). INTERPRETING THIS EXCEPTION - if (type == EINVALID_OPTION) then - There was an undefined option on the command line - item == The invalid option that was on the command line - if (type == ETOO_FEW_ARGS) then - An option was given on the command line but it was not supplied with the required number of arguments. - item == The name of this option. - num == The number of arguments expected by this option. - if (type == ETOO_MANY_ARGS) then - An option was given on the command line such as --option=arg but this option doesn't take any arguments. - item == The name of this option. !*/ public: const std::basic_string item; const unsigned long num; }; // -------------------------- cmd_line_parser ( ); /*! 
ensures - #*this is properly initialized throws - std::bad_alloc !*/ virtual ~cmd_line_parser ( ); /*! ensures - all memory associated with *this has been released !*/ void clear( ); /*! ensures - #*this has its initial value throws - std::bad_alloc if this exception is thrown then #*this is unusable until clear() is called and succeeds !*/ void parse ( int argc, const charT** argv ); /*! requires - argv == an array of strings that was obtained from the second argument of the function main(). (i.e. argv[0] should be the token, argv[1] should be an or token, etc.) - argc == the number of strings in argv ensures - parses the command line given by argc and argv - #parsed_line() == true - #at_start() == true throws - std::bad_alloc if this exception is thrown then #*this is unusable until clear() is called successfully - cmd_line_parse_error This exception is thrown if there is an error parsing the command line. If this exception is thrown then #parsed_line() == false and all options will have their count() set to 0 but otherwise there will be no effect (i.e. all registered options will remain registered). !*/ void parse ( int argc, charT** argv ); /*! This just calls this->parse(argc,argv) and performs the necessary const_cast on argv. !*/ bool parsed_line( ) const; /*! ensures - returns true if parse() has been called successfully - returns false otherwise !*/ bool option_is_defined ( const string_type& name ) const; /*! ensures - returns true if the option has been added to the parser object by calling add_option(name). - returns false otherwise !*/ void add_option ( const string_type& name, const string_type& description, unsigned long number_of_arguments = 0 ); /*! 
requires - parsed_line() == false - option_is_defined(name) == false - name does not contain any ' ', '\t', '\n', or '=' characters - name[0] != '-' - name.size() > 0 ensures - #option_is_defined(name) == true - #at_start() == true - #option(name).count() == 0 - #option(name).description() == description - #option(name).number_of_arguments() == number_of_arguments - #option(name).group_name() == get_group_name() throws - std::bad_alloc if this exception is thrown then the add_option() function has no effect !*/ const option_type& option ( const string_type& name ) const; /*! requires - option_is_defined(name) == true ensures - returns the option specified by name !*/ unsigned long number_of_arguments( ) const; /*! requires - parsed_line() == true ensures - returns the number of arguments present in the command line. This count does not include options or their arguments. Only arguments unrelated to any option are counted. !*/ const string_type& operator[] ( unsigned long N ) const; /*! requires - parsed_line() == true - N < number_of_arguments() ensures - returns the Nth command line argument !*/ void swap ( cmd_line_parser& item ); /*! ensures - swaps *this and item !*/ void print_options ( std::basic_ostream& out ) const; /*! ensures - prints all the command line options to out. - #at_start() == true throws - any exception. if an exception is thrown then #at_start() == true but otherwise it will have no effect on the state of #*this. !*/ void print_options ( ) const; /*! ensures - prints all the command line options to cout. - #at_start() == true throws - any exception. if an exception is thrown then #at_start() == true but otherwise it will have no effect on the state of #*this. !*/ string_type get_group_name ( ) const; /*! ensures - returns the current group name. This is the group new options will be added into when added via add_option(). - The group name of an option is used by print_options(). 
In particular, it groups all options with the same group name together and displays them under a title containing the text of the group name. This allows you to group similar options together in the output of print_options(). - A group name of "" (i.e. the empty string) means that no group name is set. !*/ void set_group_name ( const string_type& group_name ); /*! ensures - #get_group_name() == group_name !*/ // ------------------------------------------------------------- // Input Validation Tools // ------------------------------------------------------------- class cmd_line_check_error : public dlib::error { /*! This is the exception thrown by the check_*() routines if they find a command line error. The interpretation of the member variables is defined below in each check_*() routine. !*/ public: const string_type opt; const string_type opt2; const string_type arg; const std::vector required_opts; }; template < typename T > void check_option_arg_type ( const string_type& option_name ) const; /*! requires - parsed_line() == true - option_is_defined(option_name) == true - T is not a pointer type ensures - all the arguments for the given option are convertible by string_cast() to an object of type T. throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EINVALID_OPTION_ARG - opt == option_name - arg == the text of the offending argument !*/ template < typename T > void check_option_arg_range ( const string_type& option_name, const T& first, const T& last ) const; /*! requires - parsed_line() == true - option_is_defined(option_name) == true - first <= last - T is not a pointer type ensures - all the arguments for the given option are convertible by string_cast() to an object of type T and the resulting value is in the range first to last inclusive. 
throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EINVALID_OPTION_ARG - opt == option_name - arg == the text of the offending argument !*/ template < typename T, size_t length > void check_option_arg_range ( const string_type& option_name, const T (&arg_set)[length] ) const; /*! requires - parsed_line() == true - option_is_defined(option_name) == true - T is not a pointer type ensures - for each argument to the given option: - this argument is convertible by string_cast() to an object of type T and the resulting value is equal to some element in the arg_set array. throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EINVALID_OPTION_ARG - opt == option_name - arg == the text of the offending argument !*/ template < size_t length > void check_option_arg_range ( const string_type& option_name, const char_type* (&arg_set)[length] ) const; /*! requires - parsed_line() == true - option_is_defined(option_name) == true ensures - for each argument to the given option: - there is a string in the arg_set array that is equal to this argument. throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EINVALID_OPTION_ARG - opt == option_name - arg == the text of the offending argument !*/ template < size_t length > void check_one_time_options ( const char_type* (&option_set)[length] ) const; /*! requires - parsed_line() == true - for all valid i: - option_is_defined(option_set[i]) == true ensures - all the options in the option_set array occur at most once on the command line. throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. 
The exception's members will be set as follows: - type == EMULTIPLE_OCCURANCES - opt == the option that occurred more than once on the command line. !*/ void check_incompatible_options ( const string_type& option_name1, const string_type& option_name2 ) const; /*! requires - parsed_line() == true - option_is_defined(option_name1) == true - option_is_defined(option_name2) == true ensures - option(option_name1).count() == 0 || option(option_name2).count() == 0 (i.e. at most, only one of the options is currently present) throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EINCOMPATIBLE_OPTIONS - opt == option_name1 - opt2 == option_name2 !*/ template < size_t length > void check_incompatible_options ( const char_type* (&option_set)[length] ) const; /*! requires - parsed_line() == true - for all valid i: - option_is_defined(option_set[i]) == true ensures - At most only one of the options in the array option_set has a count() greater than 0. (i.e. at most, only one of the options is currently present) throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EINCOMPATIBLE_OPTIONS - opt == One of the incompatible options found. - opt2 == The next incompatible option found. !*/ void check_sub_option ( const string_type& parent_option, const string_type& sub_option ) const; /*! requires - parsed_line() == true - option_is_defined(parent_option) == true - option_is_defined(sub_option) == true ensures - if (option(parent_option).count() == 0) then - option(sub_option).count() == 0 throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EMISSING_REQUIRED_OPTION - opt == sub_option. 
- required_opts == a vector that contains only parent_option. !*/ template < size_t length > void check_sub_options ( const char_type* (&parent_option_set)[length], const string_type& sub_option ) const; /*! requires - parsed_line() == true - option_is_defined(sub_option) == true - for all valid i: - option_is_defined(parent_option_set[i] == true ensures - if (option(sub_option).count() > 0) then - At least one of the options in the array parent_option_set has a count() greater than 0. (i.e. at least one of the options in parent_option_set is currently present) throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EMISSING_REQUIRED_OPTION - opt == the first option from the sub_option that is present. - required_opts == a vector containing everything from parent_option_set. !*/ template < size_t length > void check_sub_options ( const string_type& parent_option, const char_type* (&sub_option_set)[length] ) const; /*! requires - parsed_line() == true - option_is_defined(parent_option) == true - for all valid i: - option_is_defined(sub_option_set[i]) == true ensures - if (option(parent_option).count() == 0) then - for all valid i: - option(sub_option_set[i]).count() == 0 throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EMISSING_REQUIRED_OPTION - opt == the first option from the sub_option_set that is present. - required_opts == a vector that contains only parent_option. !*/ template < size_t parent_length, size_t sub_length > void check_sub_options ( const char_type* (&parent_option_set)[parent_length], const char_type* (&sub_option_set)[sub_length] ) const; /*! 
requires - parsed_line() == true - for all valid i: - option_is_defined(parent_option_set[i] == true - for all valid j: - option_is_defined(sub_option_set[j]) == true ensures - for all valid j: - if (option(sub_option_set[j]).count() > 0) then - At least one of the options in the array parent_option_set has a count() greater than 0. (i.e. at least one of the options in parent_option_set is currently present) throws - std::bad_alloc - cmd_line_check_error This exception is thrown if the ensures clause could not be satisfied. The exception's members will be set as follows: - type == EMISSING_REQUIRED_OPTION - opt == the first option from the sub_option_set that is present. - required_opts == a vector containing everything from parent_option_set. !*/ private: // restricted functions cmd_line_parser(cmd_line_parser&); // copy constructor cmd_line_parser& operator=(cmd_line_parser&); // assignment operator }; // ----------------------------------------------------------------------------------------- typedef cmd_line_parser command_line_parser; typedef cmd_line_parser wcommand_line_parser; // ----------------------------------------------------------------------------------------- template < typename charT > inline void swap ( cmd_line_parser& a, cmd_line_parser& b ) { a.swap(b); } /*! provides a global swap function !*/ // ----------------------------------------------------------------------------------------- } #endif // DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ ================================================ FILE: dlib/cmd_line_parser/cmd_line_parser_kernel_c.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_CMD_LINE_PARSER_KERNEl_C_ #define DLIB_CMD_LINE_PARSER_KERNEl_C_ #include "cmd_line_parser_kernel_abstract.h" #include "../algs.h" #include "../assert.h" #include #include "../interfaces/cmd_line_parser_option.h" #include "../string.h" namespace dlib { template < typename clp_base > class cmd_line_parser_kernel_c : public clp_base { public: typedef typename clp_base::char_type char_type; typedef typename clp_base::string_type string_type; typedef typename clp_base::option_type option_type; void add_option ( const string_type& name, const string_type& description, unsigned long number_of_arguments = 0 ); const option_type& option ( const string_type& name ) const; unsigned long number_of_arguments( ) const; const option_type& element ( ) const; option_type& element ( ); const string_type& operator[] ( unsigned long N ) const; }; template < typename clp_base > inline void swap ( cmd_line_parser_kernel_c& a, cmd_line_parser_kernel_c& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename clp_base > const typename clp_base::string_type& cmd_line_parser_kernel_c:: operator[] ( unsigned long N ) const { // make sure requires clause is not broken DLIB_CASSERT( this->parsed_line() == true && N < number_of_arguments(), "\tvoid cmd_line_parser::operator[](unsigned long N)" << "\n\tYou must specify a valid index N and the parser must have run already." 
<< "\n\tthis: " << this << "\n\tN: " << N << "\n\tparsed_line(): " << this->parsed_line() << "\n\tnumber_of_arguments(): " << number_of_arguments() ); return clp_base::operator[](N); } // ---------------------------------------------------------------------------------------- template < typename clp_base > void cmd_line_parser_kernel_c:: add_option ( const string_type& name, const string_type& description, unsigned long number_of_arguments ) { // make sure requires clause is not broken DLIB_CASSERT( this->parsed_line() == false && name.size() > 0 && this->option_is_defined(name) == false && name.find_first_of(_dT(char_type," \t\n=")) == string_type::npos && name[0] != '-', "\tvoid cmd_line_parser::add_option(const string_type&,const string_type&,unsigned long)" << "\n\tsee the requires clause of add_option()" << "\n\tthis: " << this << "\n\tname.size(): " << static_cast(name.size()) << "\n\tname: \"" << narrow(name) << "\"" << "\n\tparsed_line(): " << (this->parsed_line()? "true" : "false") << "\n\tis_option_defined(\"" << narrow(name) << "\"): " << (this->option_is_defined(name)? 
"true" : "false") ); clp_base::add_option(name,description,number_of_arguments); } // ---------------------------------------------------------------------------------------- template < typename clp_base > const typename clp_base::option_type& cmd_line_parser_kernel_c:: option ( const string_type& name ) const { // make sure requires clause is not broken DLIB_CASSERT( this->option_is_defined(name) == true, "\toption cmd_line_parser::option(const string_type&)" << "\n\tto get an option it must be defined by a call to add_option()" << "\n\tthis: " << this << "\n\tname: \"" << narrow(name) << "\"" ); return clp_base::option(name); } // ---------------------------------------------------------------------------------------- template < typename clp_base > unsigned long cmd_line_parser_kernel_c:: number_of_arguments( ) const { // make sure requires clause is not broken DLIB_CASSERT( this->parsed_line() == true , "\tunsigned long cmd_line_parser::number_of_arguments()" << "\n\tyou must parse the command line before you can find out how many arguments it has" << "\n\tthis: " << this ); return clp_base::number_of_arguments(); } // ---------------------------------------------------------------------------------------- template < typename clp_base > const typename clp_base::option_type& cmd_line_parser_kernel_c:: element ( ) const { // make sure requires clause is not broken DLIB_CASSERT(this->current_element_valid() == true, "\tconst cmd_line_parser_option& cmd_line_parser::element()" << "\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); // call the real function return clp_base::element(); } // ---------------------------------------------------------------------------------------- template < typename clp_base > typename clp_base::option_type& cmd_line_parser_kernel_c:: element ( ) { // make sure requires clause is not broken DLIB_CASSERT(this->current_element_valid() == true, "\tcmd_line_parser_option& cmd_line_parser::element()" << 
"\n\tyou can't access the current element if it doesn't exist" << "\n\tthis: " << this ); // call the real function return clp_base::element(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CMD_LINE_PARSER_KERNEl_C_ ================================================ FILE: dlib/cmd_line_parser/cmd_line_parser_print_1.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CMD_LINE_PARSER_PRINt_1_ #define DLIB_CMD_LINE_PARSER_PRINt_1_ #include "cmd_line_parser_kernel_abstract.h" #include "../algs.h" #include "../string.h" #include #include #include #include #include namespace dlib { template < typename clp_base > class cmd_line_parser_print_1 : public clp_base { public: void print_options ( std::basic_ostream& out ) const; void print_options ( ) const { print_options(std::cout); } }; template < typename clp_base > inline void swap ( cmd_line_parser_print_1& a, cmd_line_parser_print_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename clp_base > void cmd_line_parser_print_1:: print_options ( std::basic_ostream& out ) const { typedef typename clp_base::char_type ct; typedef std::basic_string string; typedef typename string::size_type size_type; typedef std::basic_ostringstream ostringstream; try { size_type max_len = 0; this->reset(); // this loop here is just the bottom loop but without the print statements. // I'm doing this to figure out what len should be. 
while (this->move_next()) { size_type len = 0; len += 3; if (this->element().name().size() > 1) { ++len; } len += this->element().name().size(); if (this->element().number_of_arguments() == 1) { len += 6; } else { for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) { len += 7; if (i+1 > 9) ++len; } } len += 3; if (len < 33) max_len = std::max(max_len,len); } // Make a separate ostringstream for each option group. We are going to write // the output for each group to a separate ostringstream so that we can keep // them grouped together in the final output. std::map > groups; this->reset(); while(this->move_next()) { if (!groups[this->element().group_name()]) groups[this->element().group_name()].reset(new ostringstream); } this->reset(); while (this->move_next()) { ostringstream& sout = *groups[this->element().group_name()]; size_type len = 0; sout << _dT(ct,"\n -"); len += 3; if (this->element().name().size() > 1) { sout << _dT(ct,"-"); ++len; } sout << this->element().name(); len += this->element().name().size(); if (this->element().number_of_arguments() == 1) { sout << _dT(ct," "); len += 6; } else { for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) { sout << _dT(ct," "); len += 7; if (i+1 > 9) ++len; } } sout << _dT(ct," "); len += 3; while (len < max_len) { ++len; sout << _dT(ct," "); } const unsigned long ml = static_cast(max_len); // now print the description but make it wrap around nicely if it // is to long to fit on one line. if (len <= max_len) sout << wrap_string(this->element().description(),0,ml); else sout << _dT(ct,"\n") << wrap_string(this->element().description(),ml,ml); } // Only print out a generic Options: group name if there is an unnamed option // present. 
if (groups.count(string()) == 1) out << _dT(ct,"Options:"); // Now print everything out typename std::map >::iterator i; for (i = groups.begin(); i != groups.end(); ++i) { // print the group name if we have one if (i->first.size() != 0) { if (i != groups.begin()) out << _dT(ct,"\n\n"); out << i->first << _dT(ct,":"); } // print the options in the group out << i->second->str(); } out << _dT(ct,"\n\n"); this->reset(); } catch (...) { this->reset(); throw; } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CMD_LINE_PARSER_PRINt_1_ ================================================ FILE: dlib/cmd_line_parser/get_option.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_GET_OPTiON_Hh_ #define DLIB_GET_OPTiON_Hh_ #include "get_option_abstract.h" #include "../string.h" #include "../is_kind.h" namespace dlib { // ---------------------------------------------------------------------------------------- class option_parse_error : public error { public: option_parse_error(const std::string& option_string, const std::string& str): error(EOPTION_PARSE,"Error parsing argument for option '" + option_string + "', offending string is '" + str + "'.") {} }; // ---------------------------------------------------------------------------------------- template T impl_config_reader_get_option ( const config_reader_type& cr, const std::string& option_name, const std::string& full_option_name, T default_value ) { std::string::size_type pos = option_name.find_first_of("."); if (pos == std::string::npos) { if (cr.is_key_defined(option_name)) { try{ return string_cast(cr[option_name]); } catch (string_cast_error&) { throw option_parse_error(full_option_name, cr[option_name]); } } } else { std::string block_name = option_name.substr(0,pos); if (cr.is_block_defined(block_name)) { return 
impl_config_reader_get_option(cr.block(block_name), option_name.substr(pos+1), full_option_name, default_value); } } return default_value; } // ---------------------------------------------------------------------------------------- template typename enable_if,T>::type get_option ( const cr_type& cr, const std::string& option_name, T default_value ) { return impl_config_reader_get_option(cr, option_name, option_name, default_value); } // ---------------------------------------------------------------------------------------- template typename disable_if,T>::type get_option ( const parser_type& parser, const std::string& option_name, T default_value ) { // make sure requires clause is not broken DLIB_ASSERT( parser.option_is_defined(option_name) == true && parser.option(option_name).number_of_arguments() == 1, "\t T get_option()" << "\n\t option_name: " << option_name << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() ); if (parser.option(option_name)) { try { default_value = string_cast(parser.option(option_name).argument()); } catch (string_cast_error&) { throw option_parse_error(option_name, parser.option(option_name).argument()); } } return default_value; } // ---------------------------------------------------------------------------------------- template typename disable_if,T>::type get_option ( const parser_type& parser, const cr_type& cr, const std::string& option_name, T default_value ) { // make sure requires clause is not broken DLIB_ASSERT( parser.option_is_defined(option_name) == true && parser.option(option_name).number_of_arguments() == 1, "\t T get_option()" << "\n\t option_name: " << option_name << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() ); 
if (parser.option(option_name)) return get_option(parser, option_name, default_value); else return get_option(cr, option_name, default_value); } // ---------------------------------------------------------------------------------------- template typename disable_if,T>::type get_option ( const cr_type& cr, const parser_type& parser, const std::string& option_name, T default_value ) { // make sure requires clause is not broken DLIB_ASSERT( parser.option_is_defined(option_name) == true && parser.option(option_name).number_of_arguments() == 1, "\t T get_option()" << "\n\t option_name: " << option_name << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() ); if (parser.option(option_name)) return get_option(parser, option_name, default_value); else return get_option(cr, option_name, default_value); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template inline std::string get_option ( const T& cr, const std::string& option_name, const char* default_value ) { return get_option(cr, option_name, std::string(default_value)); } // ---------------------------------------------------------------------------------------- template inline std::string get_option ( const T& parser, const U& cr, const std::string& option_name, const char* default_value ) { return get_option(parser, cr, option_name, std::string(default_value)); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_GET_OPTiON_Hh_ ================================================ FILE: dlib/cmd_line_parser/get_option_abstract.h ================================================ // Copyright (C) 2012 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_GET_OPTiON_ABSTRACT_Hh_ #ifdef DLIB_GET_OPTiON_ABSTRACT_Hh_ #inclue namespace dlib { // ---------------------------------------------------------------------------------------- class option_parse_error : public error { /*! WHAT THIS OBJECT REPRESENTS This is the exception thrown by the get_option() functions. It is thrown when the option string given by a command line parser or config reader can't be converted into the type T. !*/ }; // ---------------------------------------------------------------------------------------- template < typename config_reader_type, typename T > T get_option ( const config_reader_type& cr, const std::string& option_name, T default_value ); /*! requires - T is a type which can be read from an input stream - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h ensures - option_name is used to index into the given config_reader. - if (cr contains an entry corresponding to option_name) then - converts the string value in cr corresponding to option_name into an object of type T and returns it. - else - returns default_value - The scheme for indexing into cr based on option_name is best understood by looking at a few examples: - an option name of "name" corresponds to cr["name"] - an option name of "block1.name" corresponds to cr.block("block1")["name"] - an option name of "block1.block2.name" corresponds to cr.block("block1").block("block2")["name"] throws - option_parse_error This exception is thrown if we attempt but fail to convert the string value in cr into an object of type T. !*/ // ---------------------------------------------------------------------------------------- template < typename command_line_parser_type, typename T > T get_option ( const command_line_parser_type& parser, const std::string& option_name, T default_value ); /*! 
requires - parser.option_is_defined(option_name) == true - parser.option(option_name).number_of_arguments() == 1 - T is a type which can be read from an input stream - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h ensures - if (parser.option(option_name)) then - converts parser.option(option_name).argument() into an object of type T and returns it. That is, the string argument to this command line option is converted into a T and returned. - else - returns default_value throws - option_parse_error This exception is thrown if we attempt but fail to convert the string argument into an object of type T. !*/ // ---------------------------------------------------------------------------------------- template < typename command_line_parser_type, typename config_reader_type, typename T > T get_option ( const command_line_parser_type& parser, const config_reader_type& cr, const std::string& option_name, T default_value ); /*! requires - parser.option_is_defined(option_name) == true - parser.option(option_name).number_of_arguments() == 1 - T is a type which can be read from an input stream - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h ensures - if (parser.option(option_name)) then - returns get_option(parser, option_name, default_value) - else - returns get_option(cr, option_name, default_value) !*/ // ---------------------------------------------------------------------------------------- template < typename command_line_parser_type, typename config_reader_type, typename T > T get_option ( const config_reader_type& cr, const command_line_parser_type& parser, const std::string& option_name, T default_value ); /*! 
requires - parser.option_is_defined(option_name) == true - parser.option(option_name).number_of_arguments() == 1 - T is a type which can be read from an input stream - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h ensures - if (parser.option(option_name)) then - returns get_option(parser, option_name, default_value) - else - returns get_option(cr, option_name, default_value) !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_GET_OPTiON_ABSTRACT_Hh_ ================================================ FILE: dlib/cmd_line_parser.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CMD_LINE_PARSEr_ #define DLIB_CMD_LINE_PARSEr_ #include "cmd_line_parser/cmd_line_parser_kernel_1.h" #include "cmd_line_parser/cmd_line_parser_kernel_c.h" #include "cmd_line_parser/cmd_line_parser_print_1.h" #include "cmd_line_parser/cmd_line_parser_check_1.h" #include "cmd_line_parser/cmd_line_parser_check_c.h" #include #include "cmd_line_parser/get_option.h" #include "map.h" #include "sequence.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename charT > class impl_cmd_line_parser { /*! This class is basically just a big templated typedef for building a complete command line parser type out of all the parts it needs. 
!*/ impl_cmd_line_parser() {} typedef typename sequence >::kernel_2a sequence_2a; typedef typename sequence*>::kernel_2a psequence_2a; typedef typename map,void*>::kernel_1a map_1a_string; public: typedef cmd_line_parser_kernel_1 kernel_1a; typedef cmd_line_parser_kernel_c kernel_1a_c; typedef cmd_line_parser_print_1 print_1a_c; typedef cmd_line_parser_check_c > check_1a_c; }; // ---------------------------------------------------------------------------------------- template < typename charT > class cmd_line_parser : public impl_cmd_line_parser::check_1a_c { public: // These typedefs are here for backwards compatibility with previous versions of dlib. typedef cmd_line_parser kernel_1a; typedef cmd_line_parser kernel_1a_c; typedef cmd_line_parser print_1a; typedef cmd_line_parser print_1a_c; typedef cmd_line_parser check_1a; typedef cmd_line_parser check_1a_c; }; template < typename charT > inline void swap ( cmd_line_parser& a, cmd_line_parser& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- typedef cmd_line_parser command_line_parser; typedef cmd_line_parser wcommand_line_parser; // ---------------------------------------------------------------------------------------- } #endif // DLIB_CMD_LINE_PARSEr_ ================================================ FILE: dlib/compress_stream/compress_stream_kernel_1.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_COMPRESS_STREAM_KERNEl_1_ #define DLIB_COMPRESS_STREAM_KERNEl_1_ #include "../algs.h" #include #include #include #include "compress_stream_kernel_abstract.h" namespace dlib { template < typename fce, typename fcd, typename crc32 > class compress_stream_kernel_1 { /*! REQUIREMENTS ON fce is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h the alphabet_size of fce must be 257. 
fce and fcd share the same kernel number. REQUIREMENTS ON fcd is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h the alphabet_size of fcd must be 257. fce and fcd share the same kernel number. REQUIREMENTS ON crc32 is an implementation of crc32/crc32_kernel_abstract.h INITIAL VALUE this object has no state CONVENTION this object has no state !*/ const static unsigned long eof_symbol = 256; public: class decompression_error : public dlib::error { public: decompression_error( const char* i ) : dlib::error(std::string(i)) {} decompression_error( const std::string& i ) : dlib::error(i) {} }; compress_stream_kernel_1 ( ) {} ~compress_stream_kernel_1 ( ) {} void compress ( std::istream& in, std::ostream& out ) const; void decompress ( std::istream& in, std::ostream& out ) const; private: // restricted functions compress_stream_kernel_1(compress_stream_kernel_1&); // copy constructor compress_stream_kernel_1& operator=(compress_stream_kernel_1&); // assignment operator }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename fce, typename fcd, typename crc32 > void compress_stream_kernel_1:: compress ( std::istream& in_, std::ostream& out_ ) const { std::streambuf::int_type temp; std::streambuf& in = *in_.rdbuf(); typename fce::entropy_encoder_type coder; coder.set_stream(out_); fce model(coder); crc32 crc; unsigned long count = 0; while (true) { // write out a known value every 20000 symbols if (count == 20000) { count = 0; coder.encode(1500,1501,8000); } ++count; // get the next character temp = in.sbumpc(); // if we have hit EOF then encode the marker symbol if (temp != EOF) { 
// encode the symbol
                model.encode(static_cast(temp));
                crc.add(static_cast(temp));
                continue;
            }
            else
            {
                // end of input: emit the eof marker followed by the four CRC bytes,
                // most significant byte first
                model.encode(eof_symbol);
                // now write the checksum
                unsigned long checksum = crc.get_checksum();
                unsigned char byte1 = static_cast((checksum>>24)&0xFF);
                unsigned char byte2 = static_cast((checksum>>16)&0xFF);
                unsigned char byte3 = static_cast((checksum>>8)&0xFF);
                unsigned char byte4 = static_cast((checksum)&0xFF);

                model.encode(byte1);
                model.encode(byte2);
                model.encode(byte3);
                model.encode(byte4);

                break;
            }
        }
    }

// ----------------------------------------------------------------------------------------

    // Reverses compress(): decodes symbols with the fcd model and writes each one to
    // out_ (adding it to the CRC) until eof_symbol is decoded, then decodes the four
    // checksum bytes (most significant first) and compares them with the CRC of the
    // bytes written.  Every 20000 iterations the encoder coded the range [1500,1501)
    // of 8000 as a synchronization marker; if it is not found here a
    // decompression_error is thrown.  A failed write throws std::ios::failure.
    // NOTE(review): the angle-bracket template arguments (e.g. on static_cast and on
    // the qualified name compress_stream_kernel_1::) appear to have been stripped in
    // this copy; restore them from upstream before compiling.
    template <
        typename fce,
        typename fcd,
        typename crc32
        >
    void compress_stream_kernel_1::
    decompress (
        std::istream& in_,
        std::ostream& out_
    ) const
    {
        std::streambuf& out = *out_.rdbuf();

        typename fcd::entropy_decoder_type coder;
        coder.set_stream(in_);

        fcd model(coder);

        unsigned long symbol;
        unsigned long count = 0;

        crc32 crc;

        // decode until we hit the marker symbol
        while (true)
        {
            // make sure this is the value we expect
            if (count == 20000)
            {
                if (coder.get_target(8000) != 1500)
                {
                    throw decompression_error("Error detected in compressed data stream.");
                }
                count = 0;
                coder.decode(1500,1501);
            }
            ++count;

            // decode the next symbol
            model.decode(symbol);
            if (symbol != eof_symbol)
            {
                crc.add(static_cast(symbol));
                // write this symbol to out
                if (out.sputc(static_cast(symbol)) != static_cast(symbol))
                {
                    throw std::ios::failure("error occurred in compress_stream_kernel_1::decompress");
                }
                continue;
            }
            else
            {
                // we read eof from the encoded data. now we just have to check the checksum and we are done.
                unsigned char byte1;
                unsigned char byte2;
                unsigned char byte3;
                unsigned char byte4;

                model.decode(symbol); byte1 = static_cast(symbol);
                model.decode(symbol); byte2 = static_cast(symbol);
                model.decode(symbol); byte3 = static_cast(symbol);
                model.decode(symbol); byte4 = static_cast(symbol);

                // reassemble the big-endian 32 bit checksum
                unsigned long checksum = byte1;
                checksum <<= 8;
                checksum |= byte2;
                checksum <<= 8;
                checksum |= byte3;
                checksum <<= 8;
                checksum |= byte4;

                if (checksum != crc.get_checksum())
                    throw decompression_error("Error detected in compressed data stream.");

                break;
            }
        } // while (true)
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_COMPRESS_STREAM_KERNEl_1_

================================================
FILE: dlib/compress_stream/compress_stream_kernel_2.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_COMPRESS_STREAM_KERNEl_2_
#define DLIB_COMPRESS_STREAM_KERNEl_2_

#include "../algs.h"
// NOTE(review): the two header names below appear to have been stripped in this
// copy (angle-bracket contents missing); restore them from upstream.
#include
#include
#include "compress_stream_kernel_abstract.h"

namespace dlib
{

    template <
        typename fce,
        typename fcd,
        typename lz77_buffer,
        typename sliding_buffer,
        typename fce_length,
        typename fcd_length,
        typename fce_index,
        typename fcd_index,
        typename crc32
        >
    class compress_stream_kernel_2
    {
        /*!
            REQUIREMENTS ON fce
                is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h
                the alphabet_size of fce must be 257.
                fce and fcd share the same kernel number.

            REQUIREMENTS ON fcd
                is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h
                the alphabet_size of fcd must be 257.
                fce and fcd share the same kernel number.
REQUIREMENTS ON lz77_buffer is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h REQUIREMENTS ON sliding_buffer is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h and is instantiated with T = unsigned char REQUIREMENTS ON fce_length is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h the alphabet_size of fce_length must be 513. This will be used to encode the length of lz77 matches. fce_length and fcd_length share the same kernel number. REQUIREMENTS ON fcd_length is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h the alphabet_size of fcd_length must be 513. This will be used to decode the length of lz77 matches. fce_length and fcd_length share the same kernel number. REQUIREMENTS ON fce_index is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h the alphabet_size of fce_index must be 32257. This will be used to encode the index of lz77 matches. fce_index and fcd_index share the same kernel number. REQUIREMENTS ON fcd_index is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h the alphabet_size of fcd_index must be 32257. This will be used to decode the index of lz77 matches. fce_index and fcd_index share the same kernel number.
REQUIREMENTS ON crc32
                is an implementation of crc32/crc32_kernel_abstract.h

            INITIAL VALUE
                this object has no state

            CONVENTION
                this object has no state
        !*/

        // Symbol marking the end of the data: the 257th symbol of the fce/fcd alphabet
        // (byte values use 0-255).
        const static unsigned long eof_symbol = 256;
    public:

        class decompression_error : public dlib::error
        {
        public:
            decompression_error(
                const char* i
            ) :
                dlib::error(std::string(i))
            {}

            decompression_error(
                const std::string& i
            ) :
                dlib::error(i)
            {}
        };


        compress_stream_kernel_2 (
        )
        {}

        ~compress_stream_kernel_2 (
        )
        {}

        void compress (
            std::istream& in,
            std::ostream& out
        ) const;

        void decompress (
            std::istream& in,
            std::ostream& out
        ) const;

    private:

        // restricted functions
        compress_stream_kernel_2(compress_stream_kernel_2&);        // copy constructor
        compress_stream_kernel_2& operator=(compress_stream_kernel_2&);    // assignment operator

    };

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Compresses in_ into out_ using a mix of LZ77 matches and per-byte PPM coding.
    //   - Input bytes are pulled into the lz77_buffer lookahead (at most
    //     LOOKAHEAD_LIMIT bytes) and added to the CRC as they are read.
    //   - Each step first codes a binary choice directly with the shared coder:
    //     range [0,lz77_count) means "LZ77 match" and [lz77_count,sum) means "PPM
    //     literal", where sum == lz77_count + ppm_count.  Both counts are halved,
    //     and kept odd so they never reach zero, once sum reaches 65536.
    //   - A match is sent as (index, length) through model_index/model_length;
    //     a literal is sent through model.
    //   - When the lookahead is empty, eof_symbol is sent as a PPM symbol,
    //     followed by the four CRC bytes (most significant first).
    //   - Every 20000 iterations the range [150,151) of 400 is coded as a
    //     synchronization marker, which decompress() checks.
    // NOTE(review): template argument lists (static_cast, compress_stream_kernel_2::)
    // appear to have been stripped in this copy; restore them from upstream.
    template <
        typename fce,
        typename fcd,
        typename lz77_buffer,
        typename sliding_buffer,
        typename fce_length,
        typename fcd_length,
        typename fce_index,
        typename fcd_index,
        typename crc32
        >
    void compress_stream_kernel_2::
    compress (
        std::istream& in_,
        std::ostream& out_
    ) const
    {
        std::streambuf::int_type temp;

        std::streambuf& in = *in_.rdbuf();

        typename fce::entropy_encoder_type coder;
        coder.set_stream(out_);

        fce model(coder);
        fce_length model_length(coder);
        fce_index model_index(coder);

        const unsigned long LOOKAHEAD_LIMIT = 512;
        // NOTE(review): the size argument 15 matches the sliding_buffer size used by
        // decompress(); presumably the two must agree -- confirm.
        lz77_buffer buffer(15,LOOKAHEAD_LIMIT);

        crc32 crc;

        unsigned long count = 0;

        unsigned long lz77_count = 1;  // number of times we used lz77 to encode
        unsigned long ppm_count = 1;   // number of times we used ppm to encode

        while (true)
        {
            // write out a known value every 20000 symbols
            if (count == 20000)
            {
                count = 0;
                coder.encode(150,151,400);
            }
            ++count;

            // try to fill the lookahead buffer
            if (buffer.get_lookahead_buffer_size() < buffer.get_lookahead_buffer_limit())
            {
                temp = in.sbumpc();
                while (temp != EOF)
                {
                    crc.add(static_cast(temp));
                    buffer.add(static_cast(temp));
                    if (buffer.get_lookahead_buffer_size() == buffer.get_lookahead_buffer_limit())
                        break;
                    temp = in.sbumpc();
                }
            }

            // compute the sum of ppm_count and lz77_count but make sure
            // it is less than 65536
            unsigned long sum = ppm_count + lz77_count;
            if (sum >= 65536)
            {
                ppm_count >>= 1;
                lz77_count >>= 1;
                ppm_count |= 1;
                lz77_count |= 1;
                sum = ppm_count+lz77_count;
            }

            // if there are still more symbols in the lookahead buffer to encode
            if (buffer.get_lookahead_buffer_size() > 0)
            {
                unsigned long match_index, match_length;
                buffer.find_match(match_index,match_length,6);
                if (match_length != 0)
                {
                    // signal the decoder that we are using lz77
                    coder.encode(0,lz77_count,sum);
                    ++lz77_count;

                    // encode the index and length pair
                    model_index.encode(match_index);
                    model_length.encode(match_length);
                }
                else
                {
                    // signal the decoder that we are using ppm
                    coder.encode(lz77_count,sum,sum);
                    ++ppm_count;

                    // encode the symbol using the ppm model
                    model.encode(buffer.lookahead_buffer(0));
                    buffer.shift_buffers(1);
                }
            }
            else
            {
                // signal the decoder that we are using ppm
                coder.encode(lz77_count,sum,sum);

                model.encode(eof_symbol);
                // now write the checksum
                unsigned long checksum = crc.get_checksum();
                unsigned char byte1 = static_cast((checksum>>24)&0xFF);
                unsigned char byte2 = static_cast((checksum>>16)&0xFF);
                unsigned char byte3 = static_cast((checksum>>8)&0xFF);
                unsigned char byte4 = static_cast((checksum)&0xFF);

                model.encode(byte1);
                model.encode(byte2);
                model.encode(byte3);
                model.encode(byte4);

                break;
            }
        } // while (true)
    }

// ----------------------------------------------------------------------------------------

    // Reverses compress().  It repeats the encoder's lz77_count/ppm_count bookkeeping
    // so that the LZ77-vs-PPM choice decodes with identical ranges.
    //   - An LZ77 block copies match_length bytes out of the sliding_buffer history.
    //   - A PPM block decodes one byte, or eof_symbol, which ends decoding and
    //     triggers the CRC check.
    // Throws decompression_error when a synchronization marker or the checksum does
    // not match, and std::ios::failure when writing to out_ fails.
    template <
        typename fce,
        typename fcd,
        typename lz77_buffer,
        typename sliding_buffer,
        typename fce_length,
        typename fcd_length,
        typename fce_index,
        typename fcd_index,
        typename crc32
        >
    void compress_stream_kernel_2::
    decompress (
        std::istream& in_,
        std::ostream& out_
    ) const
    {
        std::streambuf& out = *out_.rdbuf();

        typename fcd::entropy_decoder_type coder;
        coder.set_stream(in_);

        fcd model(coder);
        fcd_length model_length(coder);
        fcd_index model_index(coder);

        unsigned long symbol;
        unsigned long count = 0;

        sliding_buffer buffer;
        buffer.set_size(15);

        // Initialize the buffer to all zeros. There is no algorithmic reason to
        // do this. But doing so avoids a warning from valgrind so that is why
        // I'm doing this.
        for (unsigned long i = 0; i < buffer.size(); ++i)
            buffer[i] = 0;

        crc32 crc;

        unsigned long lz77_count = 1;  // number of times we used lz77 to encode
        unsigned long ppm_count = 1;   // number of times we used ppm to encode
        bool next_block_lz77;

        // decode until we hit the marker symbol
        while (true)
        {
            // make sure this is the value we expect
            if (count == 20000)
            {
                if (coder.get_target(400) != 150)
                {
                    throw decompression_error("Error detected in compressed data stream.");
                }
                count = 0;
                coder.decode(150,151);
            }
            ++count;

            // compute the sum of ppm_count and lz77_count but make sure
            // it is less than 65536
            unsigned long sum = ppm_count + lz77_count;
            if (sum >= 65536)
            {
                ppm_count >>= 1;
                lz77_count >>= 1;
                ppm_count |= 1;
                lz77_count |= 1;
                sum = ppm_count+lz77_count;
            }

            // check if we are decoding a lz77 or ppm block
            if (coder.get_target(sum) < lz77_count)
            {
                coder.decode(0,lz77_count);
                next_block_lz77 = true;
                ++lz77_count;
            }
            else
            {
                coder.decode(lz77_count,sum);
                next_block_lz77 = false;
                ++ppm_count;
            }

            if (next_block_lz77)
            {
                unsigned long match_length, match_index;
                // decode the match index
                model_index.decode(match_index);

                // decode the match length
                model_length.decode(match_length);

                // after rotating the history left by match_length, the source bytes
                // sit match_length positions further back
                match_index += match_length;
                buffer.rotate_left(match_length);
                for (unsigned long i = 0; i < match_length; ++i)
                {
                    unsigned char ch = buffer[match_index-i];
                    buffer[match_length-i-1] = ch;

                    crc.add(ch);
                    // write this ch to out
                    if (out.sputc(static_cast(ch)) != static_cast(ch))
                    {
                        throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress");
                    }
                }
            }
            else
            {
                // decode the next symbol
                model.decode(symbol);
                if (symbol != eof_symbol)
                {
                    buffer.rotate_left(1);
                    buffer[0] = static_cast(symbol);

                    crc.add(static_cast(symbol));
                    // write this symbol to out
                    if (out.sputc(static_cast(symbol)) != static_cast(symbol))
                    {
                        throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress");
                    }
                }
                else
                {
                    // this was the eof marker symbol so we are done. now check the checksum

                    // now get the checksum and make sure it matches
                    unsigned char byte1;
                    unsigned char byte2;
                    unsigned char byte3;
                    unsigned char byte4;

                    model.decode(symbol); byte1 = static_cast(symbol);
                    model.decode(symbol); byte2 = static_cast(symbol);
                    model.decode(symbol); byte3 = static_cast(symbol);
                    model.decode(symbol); byte4 = static_cast(symbol);

                    unsigned long checksum = byte1;
                    checksum <<= 8;
                    checksum |= byte2;
                    checksum <<= 8;
                    checksum |= byte3;
                    checksum <<= 8;
                    checksum |= byte4;

                    if (checksum != crc.get_checksum())
                        throw decompression_error("Error detected in compressed data stream.");

                    break;
                }
            }
        } // while (true)
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_COMPRESS_STREAM_KERNEl_2_

================================================
FILE: dlib/compress_stream/compress_stream_kernel_3.h
================================================
// Copyright (C) 2005 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_COMPRESS_STREAM_KERNEl_3_
#define DLIB_COMPRESS_STREAM_KERNEl_3_

#include "../algs.h"
#include "compress_stream_kernel_abstract.h"
#include "../assert.h"

namespace dlib
{

    template <
        typename lzp_buf,
        typename crc32,
        unsigned long buffer_size
        >
    class compress_stream_kernel_3
    {
        /*!
REQUIREMENTS ON lzp_buf is an implementation of lzp_buffer/lzp_buffer_kernel_abstract.h REQUIREMENTS ON buffer_size 10 < buffer_size < 32 REQUIREMENTS ON crc32 is an implementation of crc32/crc32_kernel_abstract.h INITIAL VALUE this object has no state CONVENTION this object has no state This implementation uses the lzp_buffer and writes out matches in a byte aligned format. !*/ public: class decompression_error : public dlib::error { public: decompression_error( const char* i ) : dlib::error(std::string(i)) {} decompression_error( const std::string& i ) : dlib::error(i) {} }; compress_stream_kernel_3 ( ) { COMPILE_TIME_ASSERT(10 < buffer_size && buffer_size < 32); } ~compress_stream_kernel_3 ( ) {} void compress ( std::istream& in, std::ostream& out ) const; void decompress ( std::istream& in, std::ostream& out ) const; private: inline void write ( unsigned char symbol ) const { if (out->sputn(reinterpret_cast(&symbol),1)==0) throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); } inline void decode ( unsigned char& symbol, unsigned char& flag ) const { if (count == 0) { if (((size_t)in->sgetn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) throw decompression_error("Error detected in compressed data stream."); count = 8; } --count; symbol = buffer[8-count]; flag = buffer[0] >> 7; buffer[0] <<= 1; } inline void encode ( unsigned char symbol, unsigned char flag ) const /*! requires - 0 <= flag <= 1 ensures - writes symbol with the given one bit flag !*/ { // add this symbol and flag to the buffer ++count; buffer[0] <<= 1; buffer[count] = symbol; buffer[0] |= flag; if (count == 8) { if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); count = 0; buffer[0] = 0; } } void clear ( ) const /*! ensures - resets the buffers !*/ { count = 0; } void flush ( ) const /*! 
ensures - flushes any data in the buffers to out !*/ { if (count != 0) { buffer[0] <<= (8-count); if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); } } mutable unsigned int count; // count tells us how many bytes are buffered in buffer and how many flag // bit are currently in buffer[0] mutable unsigned char buffer[9]; // buffer[0] holds the flag bits to be writen. // the rest of the buffer holds the bytes to be writen. mutable std::streambuf* in; mutable std::streambuf* out; // restricted functions compress_stream_kernel_3(compress_stream_kernel_3&); // copy constructor compress_stream_kernel_3& operator=(compress_stream_kernel_3&); // assignment operator }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename lzp_buf, typename crc32, unsigned long buffer_size > void compress_stream_kernel_3:: compress ( std::istream& in_, std::ostream& out_ ) const { in = in_.rdbuf(); out = out_.rdbuf(); clear(); crc32 crc; lzp_buf buffer(buffer_size); std::streambuf::int_type temp = in->sbumpc(); unsigned long index; unsigned char symbol; unsigned char length; while (temp != EOF) { symbol = static_cast(temp); if (buffer.predict_match(index)) { if (buffer[index] == symbol) { // this is a match so we must find out how long it is length = 1; buffer.add(symbol); crc.add(symbol); temp = in->sbumpc(); while (length < 255) { if (temp == EOF) { break; } else if (static_cast(length) >= index) { break; } else if (static_cast(temp) == buffer[index]) { ++length; buffer.add(static_cast(temp)); 
crc.add(static_cast(temp));
                            temp = in->sbumpc();
                        }
                        else
                        {
                            break;
                        }
                    }

                    encode(length,1);
                }
                else
                {
                    // this is also not a match
                    encode(symbol,0);
                    buffer.add(symbol);
                    crc.add(symbol);

                    // get the next symbol
                    temp = in->sbumpc();
                }
            }
            else
            {
                // there wasn't a match so just write this symbol
                encode(symbol,0);
                buffer.add(symbol);
                crc.add(symbol);

                // get the next symbol
                temp = in->sbumpc();
            }
        }

        // use a match of zero length to indicate EOF
        encode(0,1);

        // now write the checksum
        unsigned long checksum = crc.get_checksum();
        unsigned char byte1 = static_cast((checksum>>24)&0xFF);
        unsigned char byte2 = static_cast((checksum>>16)&0xFF);
        unsigned char byte3 = static_cast((checksum>>8)&0xFF);
        unsigned char byte4 = static_cast((checksum)&0xFF);

        encode(byte1,0);
        encode(byte2,0);
        encode(byte3,0);
        encode(byte4,0);

        // push out any partially filled group
        flush();
    }

// ----------------------------------------------------------------------------------------

    // Reverses compress() by reading the 9 byte groups through decode():
    //   - Flag 0 entries are literal bytes.
    //   - Flag 1 entries give a match length, and that many bytes are replayed from
    //     the position predicted by the lzp buffer.
    //   - A flag 1 entry with value 0 ends the data. The next four entries must be
    //     literals holding the CRC (most significant byte first).
    // Throws decompression_error on a short read, a CRC byte with a non-zero flag,
    // or a CRC mismatch.
    // NOTE(review): template arguments on static_cast and compress_stream_kernel_3::
    // appear to have been stripped in this copy; restore them from upstream.
    template <
        typename lzp_buf,
        typename crc32,
        unsigned long buffer_size
        >
    void compress_stream_kernel_3::
    decompress (
        std::istream& in_,
        std::ostream& out_
    ) const
    {
        in = in_.rdbuf();
        out = out_.rdbuf();
        clear();

        crc32 crc;

        lzp_buf buffer(buffer_size);

        unsigned long index = 0;
        unsigned char symbol;
        unsigned char length;
        unsigned char flag;

        decode(symbol,flag);
        while (flag == 0 || symbol != 0)
        {
            buffer.predict_match(index);

            if (flag == 1)
            {
                length = symbol;
                do
                {
                    --length;
                    symbol = buffer[index];
                    write(symbol);
                    buffer.add(symbol);
                    crc.add(symbol);
                } while (length != 0);
            }
            else
            {
                // this is just a literal
                write(symbol);
                buffer.add(symbol);
                crc.add(symbol);
            }
            decode(symbol,flag);
        }

        // now get the checksum and make sure it matches
        unsigned char byte1;
        unsigned char byte2;
        unsigned char byte3;
        unsigned char byte4;

        decode(byte1,flag);
        if (flag != 0)
            throw decompression_error("Error detected in compressed data stream.");
        decode(byte2,flag);
        if (flag != 0)
            throw decompression_error("Error detected in compressed data stream.");
        decode(byte3,flag);
        if (flag != 0)
            throw decompression_error("Error detected in compressed data stream.");
        decode(byte4,flag);
        if (flag != 0)
            throw decompression_error("Error detected in compressed data stream.");

        unsigned long checksum = byte1;
        checksum <<= 8;
        checksum |= byte2;
        checksum <<= 8;
        checksum |= byte3;
        checksum <<= 8;
        checksum |= byte4;

        if (checksum != crc.get_checksum())
            throw decompression_error("Error detected in compressed data stream.");
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_COMPRESS_STREAM_KERNEl_3_

================================================
FILE: dlib/compress_stream/compress_stream_kernel_abstract.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
// The #undef immediately followed by #ifdef on the same macro means the body of this
// specification-only header is never compiled.
#undef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_
#ifdef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_

#include "../algs.h"
// NOTE(review): the header name on the next line appears to have been stripped in
// this copy; restore it from upstream.
#include
namespace dlib
{

    class compress_stream
    {
        /*!
            INITIAL VALUE
                This object does not have any state associated with it.

            WHAT THIS OBJECT REPRESENTS
                This object consists of the two functions compress and decompress.
                These functions allow you to compress and decompress data.
        !*/

    public:

        class decompression_error : public dlib::error {};

        compress_stream (
        );
        /*!
            ensures
                - #*this is properly initialized
            throws
                - std::bad_alloc
        !*/

        virtual ~compress_stream (
        );
        /*!
            ensures
                - all memory associated with *this has been released
        !*/


        void compress (
            std::istream& in,
            std::ostream& out
        ) const;
        /*!
            ensures
                - reads all data from in (until EOF is reached) and compresses it
                  and writes it to out
            throws
                - std::ios_base::failure
                    if there was a problem writing to out then this exception will
                    be thrown.
                - any other exception
                    this exception may be thrown if there is any other problem
        !*/


        void decompress (
            std::istream& in,
            std::ostream& out
        ) const;
        /*!
            ensures
                - reads data from in, decompresses it and writes it to out.
note that it stops reading data from in when it encounters the end of the
                  compressed data, not when it encounters EOF.
            throws
                - std::ios_base::failure
                    if there was a problem writing to out then this exception will
                    be thrown.
                - decompression_error
                    if an error was detected in the compressed data that prevented
                    it from being correctly decompressed then this exception is
                    thrown.
                - any other exception
                    this exception may be thrown if there is any other problem
        !*/

    private:

        // restricted functions
        compress_stream(compress_stream&);        // copy constructor
        compress_stream& operator=(compress_stream&);    // assignment operator

    };
}

#endif // DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_

================================================
FILE: dlib/compress_stream.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_COMPRESS_STREAm_
#define DLIB_COMPRESS_STREAm_

#include "compress_stream/compress_stream_kernel_1.h"
#include "compress_stream/compress_stream_kernel_2.h"
#include "compress_stream/compress_stream_kernel_3.h"

#include "conditioning_class.h"
#include "entropy_encoder.h"
#include "entropy_decoder.h"
#include "entropy_encoder_model.h"
#include "entropy_decoder_model.h"
#include "lz77_buffer.h"
#include "sliding_buffer.h"
#include "lzp_buffer.h"
#include "crc32.h"


namespace dlib
{

    // Selector for concrete compress_stream implementations.  The constructor is
    // private (declared before "public:"), so this class only serves to hold the
    // public kernel_* typedefs below.  The private typedefs pick the entropy
    // models and buffers that those kernels are built from.
    // NOTE(review): the template argument lists of the kernel_* typedefs (and of
    // sliding_buffer) appear to have been stripped in this copy; restore them from
    // upstream before compiling.
    class compress_stream
    {
        compress_stream() {}

        // order-0 style byte models: 256 byte values plus the eof symbol
        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_1b fce1;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_1b fcd1;

        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2b fce2;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2b fcd2;

        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_3b fce3;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_3b fcd3;

        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4a fce4a;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4a fcd4a;
        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4b fce4b;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4b fcd4b;

        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5a fce5a;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5a fcd5a;
        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5b fce5b;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5b fcd5b;
        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5c fce5c;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5c fcd5c;

        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_6a fce6;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_6a fcd6;

        typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2d fce2d;
        typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2d fcd2d;

        typedef sliding_buffer::kernel_1a sliding_buffer1;
        typedef lz77_buffer::kernel_2a lz77_buffer2a;

        typedef lzp_buffer::kernel_1a lzp_buf_1;
        typedef lzp_buffer::kernel_2a lzp_buf_2;

        // models for lz77 match lengths (513 symbols) and match indexes (32257 symbols)
        typedef entropy_encoder_model<513,entropy_encoder::kernel_2a>::kernel_1b fce_length;
        typedef entropy_decoder_model<513,entropy_decoder::kernel_2a>::kernel_1b fcd_length;

        typedef entropy_encoder_model<65534,entropy_encoder::kernel_2a>::kernel_1b fce_length_2;
        typedef entropy_decoder_model<65534,entropy_decoder::kernel_2a>::kernel_1b fcd_length_2;

        typedef entropy_encoder_model<32257,entropy_encoder::kernel_2a>::kernel_1b fce_index;
        typedef entropy_decoder_model<32257,entropy_decoder::kernel_2a>::kernel_1b fcd_index;

    public:

        //----------- kernels ---------------

        // kernel_1a
        typedef compress_stream_kernel_1
                kernel_1a;

        // kernel_1b
        typedef compress_stream_kernel_1
                kernel_1b;

        // kernel_1c
        typedef compress_stream_kernel_1
                kernel_1c;

        // kernel_1da
        typedef compress_stream_kernel_1
                kernel_1da;

        // kernel_1ea
        typedef compress_stream_kernel_1
                kernel_1ea;

        // kernel_1db
        typedef compress_stream_kernel_1
                kernel_1db;

        // kernel_1eb
        typedef compress_stream_kernel_1
                kernel_1eb;

        // kernel_1ec
        typedef compress_stream_kernel_1
                kernel_1ec;

        // kernel_2a
        typedef compress_stream_kernel_2
                kernel_2a;

        // kernel_3a
        typedef compress_stream_kernel_3
                kernel_3a;

        // kernel_3b
        typedef compress_stream_kernel_3
                kernel_3b;

    };
}

#endif // DLIB_COMPRESS_STREAm_

================================================
FILE: dlib/conditioning_class/conditioning_class_kernel_1.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONDITIONING_CLASS_KERNEl_1_
#define DLIB_CONDITIONING_CLASS_KERNEl_1_

#include "conditioning_class_kernel_abstract.h"
#include "../assert.h"
#include "../algs.h"

namespace dlib
{

    template <
        unsigned long alphabet_size
        >
    class conditioning_class_kernel_1
    {
        /*!
INITIAL VALUE
                total == 1
                counts == pointer to an array of alphabet_size unsigned shorts
                for all i except i == alphabet_size-1: counts[i] == 0
                counts[alphabet_size-1] == 1

            CONVENTION
                counts == pointer to an array of alphabet_size unsigned shorts
                get_total() == total
                get_count(symbol) == counts[symbol]

                LOW_COUNT(symbol) == sum of counts[0] through counts[symbol-1]
                                     or 0 if symbol == 0

                get_memory_usage() == global_state.memory_usage
        !*/

    public:

        class global_state_type
        {
        public:
            global_state_type () : memory_usage(0) {}
        private:
            // bytes accounted for by every conditioning class built on this state
            unsigned long memory_usage;

            friend class conditioning_class_kernel_1;
        };

        conditioning_class_kernel_1 (
            global_state_type& global_state_
        );

        ~conditioning_class_kernel_1 (
        );

        void clear(
        );

        bool increment_count (
            unsigned long symbol,
            unsigned short amount = 1
        );

        unsigned long get_count (
            unsigned long symbol
        ) const;

        unsigned long get_total (
        ) const;

        unsigned long get_range (
            unsigned long symbol,
            unsigned long& low_count,
            unsigned long& high_count,
            unsigned long& total_count
        ) const;

        void get_symbol (
            unsigned long target,
            unsigned long& symbol,
            unsigned long& low_count,
            unsigned long& high_count
        ) const;

        unsigned long get_memory_usage (
        ) const;

        global_state_type& get_global_state (
        );

        static unsigned long get_alphabet_size (
        );

    private:

        // restricted functions
        conditioning_class_kernel_1(conditioning_class_kernel_1&);        // copy constructor
        conditioning_class_kernel_1& operator=(conditioning_class_kernel_1&);    // assignment operator

        // data members
        unsigned short total;
        unsigned short* counts;
        global_state_type& global_state;

    };

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Allocates the counts array, sets every count to 0 except the last symbol (count 1,
    // so total == 1), and adds this object's footprint to global_state.memory_usage.
    // NOTE(review): template arguments on the qualified names (conditioning_class_kernel_1::)
    // and on static_cast appear to have been stripped in this copy; restore from upstream.
    template <
        unsigned long alphabet_size
        >
    conditioning_class_kernel_1::
    conditioning_class_kernel_1 (
        global_state_type& global_state_
    ) :
        total(1),
        counts(new unsigned short[alphabet_size]),
        global_state(global_state_)
    {
        COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );

        unsigned short* start = counts;
        unsigned short* end = counts+alphabet_size-1;
        while (start != end)
        {
            *start = 0;
            ++start;
        }
        *start = 1;

        // update memory usage
        global_state.memory_usage += sizeof(unsigned short)*alphabet_size +
                                     sizeof(conditioning_class_kernel_1);
    }

// ----------------------------------------------------------------------------------------

    // Frees the counts array and removes this object's footprint from the shared
    // memory usage total.
    template <
        unsigned long alphabet_size
        >
    conditioning_class_kernel_1::
    ~conditioning_class_kernel_1 (
    )
    {
        delete [] counts;
        // update memory usage
        global_state.memory_usage -= sizeof(unsigned short)*alphabet_size +
                                     sizeof(conditioning_class_kernel_1);
    }

// ----------------------------------------------------------------------------------------

    // Restores the initial state: all counts 0 except the last symbol, total == 1.
    template <
        unsigned long alphabet_size
        >
    void conditioning_class_kernel_1::
    clear(
    )
    {
        total = 1;
        unsigned short* start = counts;
        unsigned short* end = counts+alphabet_size-1;
        while (start != end)
        {
            *start = 0;
            ++start;
        }
        *start = 1;
    }

// ----------------------------------------------------------------------------------------

    // Returns the memory usage recorded in the shared global_state (summed over every
    // object constructed with that same global_state).
    template <
        unsigned long alphabet_size
        >
    unsigned long conditioning_class_kernel_1::
    get_memory_usage(
    ) const
    {
        return global_state.memory_usage;
    }

// ----------------------------------------------------------------------------------------

    template <
        unsigned long alphabet_size
        >
    typename conditioning_class_kernel_1::global_state_type& conditioning_class_kernel_1::
    get_global_state(
    )
    {
        return global_state;
    }

// ----------------------------------------------------------------------------------------

    // Adds amount to counts[symbol].  If total+amount would reach 65536, every count is
    // first halved (rounding down, so counts of 1 drop to 0) and the last symbol's count
    // is forced back to at least 1.  Always returns true.
    template <
        unsigned long alphabet_size
        >
    bool conditioning_class_kernel_1::
    increment_count (
        unsigned long symbol,
        unsigned short amount
    )
    {
        // if we are going over a total of 65535 then scale down all counts by 2
        if (static_cast(total)+static_cast(amount) >= 65536)
        {
            total = 0;
            unsigned short* start = counts;
            unsigned short* end = counts+alphabet_size;
            while (start != end)
            {
                *start >>= 1;
                total += *start;
                ++start;
            }
            // make sure it is at least one
            if (counts[alphabet_size-1]==0)
            {
                ++total;
                counts[alphabet_size-1] = 1;
            }
        }
        counts[symbol] += amount;
        total += amount;
        return true;
    }

// ----------------------------------------------------------------------------------------

    template <
        unsigned long alphabet_size
        >
    unsigned long conditioning_class_kernel_1::
    get_count (
        unsigned long symbol
    ) const
    {
        return counts[symbol];
    }

// ----------------------------------------------------------------------------------------

    template <
        unsigned long alphabet_size
        >
    unsigned long conditioning_class_kernel_1::
    get_alphabet_size (
    )
    {
        return alphabet_size;
    }

// ----------------------------------------------------------------------------------------

    template <
        unsigned long alphabet_size
        >
    unsigned long conditioning_class_kernel_1::
    get_total (
    ) const
    {
        return total;
    }

// ----------------------------------------------------------------------------------------

    // If counts[symbol] == 0, returns 0 and leaves the outputs untouched.  Otherwise
    // sets total_count = total, high_count = LOW_COUNT(symbol) + counts[symbol],
    // low_count = LOW_COUNT(symbol), and returns counts[symbol].  Runs in linear
    // time in symbol.
    template <
        unsigned long alphabet_size
        >
    unsigned long conditioning_class_kernel_1::
    get_range (
        unsigned long symbol,
        unsigned long& low_count,
        unsigned long& high_count,
        unsigned long& total_count
    ) const
    {
        if (counts[symbol] == 0)
            return 0;

        total_count = total;

        const unsigned short* start = counts;
        const unsigned short* end = counts+symbol;
        // cumulative sum of counts[0] .. counts[symbol]; fits in an unsigned short
        // because total is kept below 65536
        unsigned short high_count_temp = *start;
        while (start != end)
        {
            ++start;
            high_count_temp += *start;
        }
        low_count = high_count_temp - *start;
        high_count = high_count_temp;
        return *start;
    }

// ----------------------------------------------------------------------------------------

    // Linear scan for the symbol whose cumulative range [low_count, high_count)
    // contains target; returns that symbol and its range bounds.
    // NOTE(review): target is presumably required to be < get_total(); otherwise the
    // scan runs past the end of counts.
    template <
        unsigned long alphabet_size
        >
    void conditioning_class_kernel_1::
    get_symbol (
        unsigned long target,
        unsigned long& symbol,
        unsigned long& low_count,
        unsigned long& high_count
    ) const
    {
        unsigned long high_count_temp = *counts;
        const unsigned short* start = counts;
        while (target >= high_count_temp)
        {
            ++start;
            high_count_temp += *start;
        }

        low_count = high_count_temp - *start;
        high_count = high_count_temp;
        symbol = static_cast(start-counts);
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_

================================================
FILE: dlib/conditioning_class/conditioning_class_kernel_2.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_
#define DLIB_CONDITIONING_CLASS_KERNEl_2_

#include "conditioning_class_kernel_abstract.h"
#include "../assert.h"
#include "../algs.h"

namespace dlib
{

    template <
        unsigned long alphabet_size
        >
    class conditioning_class_kernel_2
    {
        /*!
            INITIAL VALUE
                total == 1
                symbols == pointer to array of alphabet_size data structs
                for all i except i == alphabet_size-1:
                    symbols[i].count == 0
                    symbols[i].left_count == 0

                symbols[alphabet_size-1].count == 1
                symbols[alphabet_size-1].left_count == 0

            CONVENTION
                symbols == pointer to array of alphabet_size data structs
                get_total() == total
                get_count(symbol) == symbols[symbol].count

                symbols is organized as a tree with symbols[0] as the root.

                the left subchild of symbols[i] is symbols[i*2+1] and
                the right subchild is symbols[i*2+2].
the parent of symbols[i] == symbols[(i-1)/2]

                symbols[i].left_count == the sum of the counts of all the symbols
                                         in the left subtree of symbols[i]

                get_memory_usage() == global_state.memory_usage
        !*/

    public:

        class global_state_type
        {
        public:
            global_state_type () : memory_usage(0) {}
        private:
            // bytes accounted for by every conditioning class built on this state
            unsigned long memory_usage;

            friend class conditioning_class_kernel_2;
        };

        conditioning_class_kernel_2 (
            global_state_type& global_state_
        );

        ~conditioning_class_kernel_2 (
        );

        void clear(
        );

        bool increment_count (
            unsigned long symbol,
            unsigned short amount = 1
        );

        unsigned long get_count (
            unsigned long symbol
        ) const;

        inline unsigned long get_total (
        ) const;

        unsigned long get_range (
            unsigned long symbol,
            unsigned long& low_count,
            unsigned long& high_count,
            unsigned long& total_count
        ) const;

        void get_symbol (
            unsigned long target,
            unsigned long& symbol,
            unsigned long& low_count,
            unsigned long& high_count
        ) const;

        unsigned long get_memory_usage (
        ) const;

        global_state_type& get_global_state (
        );

        static unsigned long get_alphabet_size (
        );

    private:

        // restricted functions
        conditioning_class_kernel_2(conditioning_class_kernel_2&);        // copy constructor
        conditioning_class_kernel_2& operator=(conditioning_class_kernel_2&);    // assignment operator

        // data members
        unsigned short total;

        // one tree node per symbol
        struct data
        {
            unsigned short count;
            unsigned short left_count;
        };

        data* symbols;
        global_state_type& global_state;

    };

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Allocates the symbol tree and sets every count to 0 except the last symbol
    // (count 1, so total == 1).  It then walks from that symbol up to the root,
    // adding 1 to the left_count of every ancestor it reaches through a left child.
    // Finally it adds this object's footprint to global_state.memory_usage.
    // NOTE(review): template arguments on the qualified names (conditioning_class_kernel_2::)
    // and on static_cast appear to have been stripped in this copy; restore from upstream.
    template <
        unsigned long alphabet_size
        >
    conditioning_class_kernel_2::
    conditioning_class_kernel_2 (
        global_state_type& global_state_
    ) :
        total(1),
        symbols(new data[alphabet_size]),
        global_state(global_state_)
    {
        COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );

        data* start = symbols;
        data* end = symbols + alphabet_size-1;

        while (start != end)
        {
            start->count = 0;
            start->left_count = 0;
            ++start;
        }

        start->count = 1;
        start->left_count = 0;


        // update the left_counts for the symbol alphabet_size-1
        unsigned short temp;
        unsigned long symbol = alphabet_size-1;
        while (symbol != 0)
        {
            // temp will be 1 if symbol is odd, 0 if it is even
            temp = static_cast(symbol&0x1);

            // set symbol to its parent
            symbol = (symbol-1)>>1;

            // note that all left subchildren are odd and also that
            // if symbol was a left subchild then we want to increment
            // its parents left_count
            if (temp)
                ++symbols[symbol].left_count;
        }

        global_state.memory_usage += sizeof(data)*alphabet_size +
                                     sizeof(conditioning_class_kernel_2);

    }

// ----------------------------------------------------------------------------------------

    // Frees the symbol tree and removes this object's footprint from the shared
    // memory usage total.
    template <
        unsigned long alphabet_size
        >
    conditioning_class_kernel_2::
    ~conditioning_class_kernel_2 (
    )
    {
        delete [] symbols;
        global_state.memory_usage -= sizeof(data)*alphabet_size +
                                     sizeof(conditioning_class_kernel_2);
    }

// ----------------------------------------------------------------------------------------

    // Restores the initial state using the same steps as the constructor:
    // all counts 0 except the last symbol, total == 1, and left_counts rebuilt
    // along the path from the last symbol to the root.
    template <
        unsigned long alphabet_size
        >
    void conditioning_class_kernel_2::
    clear(
    )
    {
        data* start = symbols;
        data* end = symbols + alphabet_size-1;

        total = 1;

        while (start != end)
        {
            start->count = 0;
            start->left_count = 0;
            ++start;
        }

        start->count = 1;
        start->left_count = 0;

        // update the left_counts
        unsigned short temp;
        unsigned long symbol = alphabet_size-1;
        while (symbol != 0)
        {
            // temp will be 1 if symbol is odd, 0 if it is even
            temp = static_cast(symbol&0x1);

            // set symbol to its parent
            symbol = (symbol-1)>>1;

            // note that all left subchildren are odd and also that
            // if symbol was a left subchild then we want to increment
            // its parents left_count
            symbols[symbol].left_count += temp;
        }

    }

//
---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_2:: get_memory_usage( ) const { return global_state.memory_usage; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > typename conditioning_class_kernel_2::global_state_type& conditioning_class_kernel_2:: get_global_state( ) { return global_state; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > bool conditioning_class_kernel_2:: increment_count ( unsigned long symbol, unsigned short amount ) { // if we need to renormalize then do so if (static_cast(total)+static_cast(amount) >= 65536) { unsigned long s; unsigned short temp; for (unsigned short i = 0; i < alphabet_size-1; ++i) { s = i; // divide the count for this symbol by 2 symbols[i].count >>= 1; symbols[i].left_count = 0; // bubble this change up though the tree while (s != 0) { // temp will be 1 if symbol is odd, 0 if it is even temp = static_cast(s&0x1); // set s to its parent s = (s-1)>>1; // note that all left subchidren are odd and also that // if s was a left subchild then we want to increment // its parents left_count if (temp) symbols[s].left_count += symbols[i].count; } } // update symbols alphabet_size-1 { s = alphabet_size-1; // divide alphabet_size-1 symbol by 2 if it's > 1 if (symbols[alphabet_size-1].count > 1) symbols[alphabet_size-1].count >>= 1; // bubble this change up though the tree while (s != 0) { // temp will be 1 if symbol is odd, 0 if it is even temp = static_cast(s&0x1); // set s to its parent s = (s-1)>>1; // note that all left subchidren are odd and also that // if s was a left subchild then we want to increment // its parents left_count if (temp) symbols[s].left_count += symbols[alphabet_size-1].count; } } // calculate the new total total = 0; 
unsigned long m = 0; while (m < alphabet_size) { total += symbols[m].count + symbols[m].left_count; m = (m<<1) + 2; } } // increment the count for the specified symbol symbols[symbol].count += amount;; total += amount; unsigned short temp; while (symbol != 0) { // temp will be 1 if symbol is odd, 0 if it is even temp = static_cast(symbol&0x1); // set symbol to its parent symbol = (symbol-1)>>1; // note that all left subchidren are odd and also that // if symbol was a left subchild then we want to increment // its parents left_count if (temp) symbols[symbol].left_count += amount; } return true; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_2:: get_count ( unsigned long symbol ) const { return symbols[symbol].count; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_2:: get_alphabet_size ( ) { return alphabet_size; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_2:: get_total ( ) const { return total; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_2:: get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const { if (symbols[symbol].count == 0) return 0; unsigned long current = symbol; total_count = total; unsigned long high_count_temp = 0; bool came_from_right = true; while (true) { if (came_from_right) { high_count_temp += symbols[current].count + symbols[current].left_count; } // note that if current is even then it is a right child came_from_right = !(current&0x1); if (current == 0) break; // set current 
to its parent current = (current-1)>>1 ; } low_count = high_count_temp - symbols[symbol].count; high_count = high_count_temp; return symbols[symbol].count; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > void conditioning_class_kernel_2:: get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const { unsigned long current = 0; unsigned long low_count_temp = 0; while (true) { if (static_cast(target) < symbols[current].left_count) { // we should go left current = (current<<1) + 1; } else { target -= symbols[current].left_count; low_count_temp += symbols[current].left_count; if (static_cast(target) < symbols[current].count) { // we have found our target symbol = current; high_count = low_count_temp + symbols[current].count; low_count = low_count_temp; break; } else { // go right target -= symbols[current].count; low_count_temp += symbols[current].count; current = (current<<1) + 2; } } } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ ================================================ FILE: dlib/conditioning_class/conditioning_class_kernel_3.h ================================================ // Copyright (C) 2004 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONDITIONING_CLASS_KERNEl_3_ #define DLIB_CONDITIONING_CLASS_KERNEl_3_ #include "conditioning_class_kernel_abstract.h" #include "../assert.h" #include "../algs.h" namespace dlib { template < unsigned long alphabet_size > class conditioning_class_kernel_3 { /*! 
INITIAL VALUE total == 1 counts == pointer to an array of alphabet_size data structs for all i except i == 0: counts[i].count == 0 counts[0].count == 1 counts[0].symbol == alphabet_size-1 for all i except i == alphabet_size-1: counts[i].present == false counts[alphabet_size-1].present == true CONVENTION counts == pointer to an array of alphabet_size data structs get_total() == total get_count(symbol) == counts[x].count where counts[x].symbol == symbol LOW_COUNT(symbol) == sum of counts[0].count though counts[x-1].count where counts[x].symbol == symbol if (counts[0].symbol == symbol) LOW_COUNT(symbol)==0 if (counts[i].count == 0) then counts[i].symbol == undefined value if (symbol has a nonzero count) then counts[symbol].present == true get_memory_usage() == global_state.memory_usage !*/ public: class global_state_type { public: global_state_type () : memory_usage(0) {} private: unsigned long memory_usage; friend class conditioning_class_kernel_3; }; conditioning_class_kernel_3 ( global_state_type& global_state_ ); ~conditioning_class_kernel_3 ( ); void clear( ); bool increment_count ( unsigned long symbol, unsigned short amount = 1 ); unsigned long get_count ( unsigned long symbol ) const; unsigned long get_total ( ) const; unsigned long get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const; void get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const; unsigned long get_memory_usage ( ) const; global_state_type& get_global_state ( ); static unsigned long get_alphabet_size ( ); private: // restricted functions conditioning_class_kernel_3(conditioning_class_kernel_3&); // copy constructor conditioning_class_kernel_3& operator=(conditioning_class_kernel_3&); // assignment operator struct data { unsigned short count; unsigned short symbol; bool present; }; // data members unsigned short total; data* counts; global_state_type& global_state; 
}; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > conditioning_class_kernel_3:: conditioning_class_kernel_3 ( global_state_type& global_state_ ) : total(1), counts(new data[alphabet_size]), global_state(global_state_) { COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); data* start = counts; data* end = counts+alphabet_size; start->count = 1; start->symbol = alphabet_size-1; start->present = false; ++start; while (start != end) { start->count = 0; start->present = false; ++start; } counts[alphabet_size-1].present = true; // update memory usage global_state.memory_usage += sizeof(data)*alphabet_size + sizeof(conditioning_class_kernel_3); } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > conditioning_class_kernel_3:: ~conditioning_class_kernel_3 ( ) { delete [] counts; // update memory usage global_state.memory_usage -= sizeof(data)*alphabet_size + sizeof(conditioning_class_kernel_3); } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > void conditioning_class_kernel_3:: clear( ) { total = 1; data* start = counts; data* end = counts+alphabet_size; start->count = 1; start->symbol = alphabet_size-1; start->present = false; ++start; while (start != end) { start->count = 0; start->present = false; ++start; } counts[alphabet_size-1].present = true; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > typename 
conditioning_class_kernel_3::global_state_type& conditioning_class_kernel_3:: get_global_state( ) { return global_state; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_3:: get_memory_usage( ) const { return global_state.memory_usage; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > bool conditioning_class_kernel_3:: increment_count ( unsigned long symbol, unsigned short amount ) { // if we are going over a total of 65535 then scale down all counts by 2 if (static_cast(total)+static_cast(amount) >= 65536) { total = 0; data* start = counts; data* end = counts+alphabet_size; while (start != end) { if (start->count == 1) { if (start->symbol == alphabet_size-1) { // this symbol must never be zero so we will leave its count at 1 ++total; } else { start->count = 0; counts[start->symbol].present = false; } } else { start->count >>= 1; total += start->count; } ++start; } } data* start = counts; data* swap_spot = counts; if (counts[symbol].present) { while (true) { if (start->symbol == symbol && start->count!=0) { unsigned short temp = start->count + amount; start->symbol = swap_spot->symbol; start->count = swap_spot->count; swap_spot->symbol = static_cast(symbol); swap_spot->count = temp; break; } if ( (start->count) < (swap_spot->count)) { swap_spot = start; } ++start; } } else { counts[symbol].present = true; while (true) { if (start->count == 0) { start->symbol = swap_spot->symbol; start->count = swap_spot->count; swap_spot->symbol = static_cast(symbol); swap_spot->count = amount; break; } if ((start->count) < (swap_spot->count)) { swap_spot = start; } ++start; } } total += amount; return true; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long 
conditioning_class_kernel_3:: get_count ( unsigned long symbol ) const { if (counts[symbol].present == false) return 0; data* start = counts; while (start->symbol != symbol) { ++start; } return start->count; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_3:: get_alphabet_size ( ) { return alphabet_size; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_3:: get_total ( ) const { return total; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > unsigned long conditioning_class_kernel_3:: get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const { if (counts[symbol].present == false) return 0; total_count = total; unsigned long low_count_temp = 0; data* start = counts; while (start->symbol != symbol) { low_count_temp += start->count; ++start; } low_count = low_count_temp; high_count = low_count_temp + start->count; return start->count; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size > void conditioning_class_kernel_3:: get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const { unsigned long high_count_temp = counts->count; const data* start = counts; while (target >= high_count_temp) { ++start; high_count_temp += start->count; } low_count = high_count_temp - start->count; high_count = high_count_temp; symbol = static_cast(start->symbol); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONDITIONING_CLASS_KERNEl_3_ ================================================ FILE: 
dlib/conditioning_class/conditioning_class_kernel_4.h ================================================ // Copyright (C) 2004 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_ #define DLIB_CONDITIONING_CLASS_KERNEl_4_ #include "conditioning_class_kernel_abstract.h" #include "../assert.h" #include "../algs.h" namespace dlib { template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > class conditioning_class_kernel_4 { /*! REQUIREMENTS ON pool_size pool_size > 0 this will be the number of nodes contained in our memory pool REQUIREMENTS ON mem_manager mem_manager is an implementation of memory_manager/memory_manager_kernel_abstract.h INITIAL VALUE total == 1 escapes == 1 next == 0 CONVENTION get_total() == total get_count(alphabet_size-1) == escapes if (next != 0) then next == pointer to the start of a linked list and the linked list is terminated by a node with a next pointer of 0. get_count(symbol) == node::count for the node where node::symbol==symbol or 0 if no such node currently exists. if (there is a node for the symbol) then LOW_COUNT(symbol) == the sum of all node's counts in the linked list up to but not including the node for the symbol. 
get_memory_usage() == global_state.memory_usage !*/ struct node { unsigned short symbol; unsigned short count; node* next; }; public: class global_state_type { public: global_state_type ( ) : memory_usage(pool_size*sizeof(node)+sizeof(global_state_type)) {} private: unsigned long memory_usage; typename mem_manager::template rebind::other pool; friend class conditioning_class_kernel_4; }; conditioning_class_kernel_4 ( global_state_type& global_state_ ); ~conditioning_class_kernel_4 ( ); void clear( ); bool increment_count ( unsigned long symbol, unsigned short amount = 1 ); unsigned long get_count ( unsigned long symbol ) const; inline unsigned long get_total ( ) const; unsigned long get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const; void get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const; unsigned long get_memory_usage ( ) const; global_state_type& get_global_state ( ); static unsigned long get_alphabet_size ( ); private: void half_counts ( ); /*! 
ensures - divides all counts by 2 but ensures that escapes is always at least 1 !*/ // restricted functions conditioning_class_kernel_4(conditioning_class_kernel_4&); // copy constructor conditioning_class_kernel_4& operator=(conditioning_class_kernel_4&); // assignment operator // data members unsigned short total; unsigned short escapes; node* next; global_state_type& global_state; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > conditioning_class_kernel_4:: conditioning_class_kernel_4 ( global_state_type& global_state_ ) : total(1), escapes(1), next(0), global_state(global_state_) { COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); // update memory usage global_state.memory_usage += sizeof(conditioning_class_kernel_4); } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > conditioning_class_kernel_4:: ~conditioning_class_kernel_4 ( ) { clear(); // update memory usage global_state.memory_usage -= sizeof(conditioning_class_kernel_4); } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > void conditioning_class_kernel_4:: clear( ) { total = 1; escapes = 1; while (next) { node* temp = next; next = next->next; global_state.pool.deallocate(temp); } } // ---------------------------------------------------------------------------------------- template < unsigned long 
alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4:: get_memory_usage( ) const { return global_state.memory_usage; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > typename conditioning_class_kernel_4::global_state_type& conditioning_class_kernel_4:: get_global_state( ) { return global_state; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > bool conditioning_class_kernel_4:: increment_count ( unsigned long symbol, unsigned short amount ) { if (symbol == alphabet_size-1) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); escapes += amount; total += amount; return true; } // find the symbol and increment it or add a new node to the list if (next) { node* temp = next; node* previous = 0; while (true) { if (temp->symbol == static_cast(symbol)) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); // we have found the symbol total += amount; temp->count += amount; // if this node now has a count greater than its parent node if (previous && temp->count > previous->count) { // swap the nodes so that the nodes will be in semi-sorted order swap(temp->count,previous->count); swap(temp->symbol,previous->symbol); } return true; } else if (temp->next == 0) { // we did not find the symbol so try to add it to the list if (global_state.pool.get_number_of_allocations() < pool_size) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); node* t = global_state.pool.allocate(); t->next = 0; t->symbol = static_cast(symbol); t->count = amount; temp->next = t; total += amount; return true; } else { // no memory left return false; } } else if (temp->count == 0) { // remove 
nodes that have a zero count if (previous) { previous->next = temp->next; node* t = temp; temp = temp->next; global_state.pool.deallocate(t); } else { next = temp->next; node* t = temp; temp = temp->next; global_state.pool.deallocate(t); } } else { previous = temp; temp = temp->next; } } // while (true) } // if there aren't any nodes in the list yet then do this instead else { if (global_state.pool.get_number_of_allocations() < pool_size) { // make sure we won't cause any overflow if (total >= 65536 - amount ) half_counts(); next = global_state.pool.allocate(); next->next = 0; next->symbol = static_cast(symbol); next->count = amount; total += amount; return true; } else { // no memory left return false; } } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4:: get_count ( unsigned long symbol ) const { if (symbol == alphabet_size-1) { return escapes; } else { node* temp = next; while (temp) { if (temp->symbol == symbol) return temp->count; temp = temp->next; } return 0; } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4:: get_alphabet_size ( ) { return alphabet_size; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4:: get_total ( ) const { return total; } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > unsigned long conditioning_class_kernel_4:: get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& 
high_count, unsigned long& total_count ) const { if (symbol != alphabet_size-1) { node* temp = next; unsigned long low = 0; while (temp) { if (temp->symbol == static_cast(symbol)) { high_count = temp->count + low; low_count = low; total_count = total; return temp->count; } low += temp->count; temp = temp->next; } return 0; } else { total_count = total; high_count = total; low_count = total-escapes; return escapes; } } // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > void conditioning_class_kernel_4:: get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const { node* temp = next; unsigned long high = 0; while (true) { if (temp != 0) { high += temp->count; if (target < high) { symbol = temp->symbol; high_count = high; low_count = high - temp->count; return; } temp = temp->next; } else { // this must be the escape symbol symbol = alphabet_size-1; low_count = total-escapes; high_count = total; return; } } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < unsigned long alphabet_size, unsigned long pool_size, typename mem_manager > void conditioning_class_kernel_4:: half_counts ( ) { total = 0; if (escapes > 1) escapes >>= 1; //divide all counts by 2 node* temp = next; while (temp) { temp->count >>= 1; total += temp->count; temp = temp->next; } total += escapes; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONDITIONING_CLASS_KERNEl_4_ 
================================================
FILE: dlib/conditioning_class/conditioning_class_kernel_abstract.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
// This is a specification-only header: the guard macro is #undef'd and then
// tested with #ifdef, so nothing below is ever compiled.  It documents the
// interface that the conditioning_class_kernel_* headers (which #include this
// file) provide.
#undef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_
#ifdef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_

#include "../algs.h"

namespace dlib
{

    template <
        unsigned long alphabet_size
        >
    class conditioning_class 
    {
        /*!
            REQUIREMENTS ON alphabet_size
                1 < alphabet_size < 65536

            INITIAL VALUE
                get_total() == 1
                get_count(X) == 0 : for all valid values of X except alphabet_size-1
                get_count(alphabet_size-1) == 1

            WHAT THIS OBJECT REPRESENTS
                This object represents a conditioning class used for arithmetic style
                compression.  It maintains the cumulative counts which are needed
                by the entropy_coder and entropy_decoder objects.

                At any moment a conditioning_class object represents a set of 
                alphabet_size symbols.  Each symbol is associated with an integer
                called its count.  

                All symbols start out with a count of zero except for alphabet_size-1.
                This last symbol will always have a count of at least one.  It is
                intended to be used as an escape into a lower context when coding
                and so it must never have a zero probability or the decoder won't
                be able to identify the escape symbol.

            NOTATION:
                Let MAP(i) be a function which maps integers to symbols.  MAP(i) is 
                one to one and onto.  Its domain is 1 to alphabet_size inclusive.

                Let RMAP(s) be the inverse of MAP(i).  
                    ( i.e.  RMAP(MAP(i)) == i and MAP(RMAP(s)) == s  )

                Let COUNT(i) give the count for the symbol MAP(i). 
                    ( i.e.  COUNT(i) == get_count(MAP(i))  )

                Let LOW_COUNT(s) == the sum of COUNT(x) for x == 1 to x == RMAP(s)-1 
                    (note that the sum of COUNT(x) for x == 1 to x == 0 is 0)
                Let HIGH_COUNT(s) == LOW_COUNT(s) + get_count(s)

                Basically what this is saying is just that you shouldn't assume you know
                what order the symbols are placed in when calculating the cumulative
                sums.  The specific mapping provided by the MAP() function is unspecified.

            THREAD SAFETY
                This object can be used safely in a multithreaded program as long as the
                global state is not shared between conditioning classes which run on
                different threads.

            GLOBAL_STATE_TYPE
                The global_state_type object allows instances of the conditioning_class
                object to share any kind of global state the implementer desires.
                However, the global_state_type object exists primarily to facilitate the
                sharing of a memory pool between many instances of a conditioning_class
                object.  But note that it is not required that there be any kind of 
                memory pool at all, it is just a possibility.
        !*/
    public:

        class global_state_type
        {
            global_state_type (
            );
            /*!
                ensures
                    - #*this is properly initialized
                throws
                    - std::bad_alloc
            !*/

            // my contents are implementation specific.
        };

        conditioning_class (
            global_state_type& global_state
        );
        /*!
            ensures
                - #*this is properly initialized
                - &#get_global_state() == &global_state
            throws
                - std::bad_alloc
        !*/

        ~conditioning_class (
        );
        /*!
            ensures
                - all memory associated with *this has been released
        !*/

        void clear(
        );
        /*!
            ensures
                - #*this has its initial value
            throws
                - std::bad_alloc
        !*/

        bool increment_count (
            unsigned long symbol,
            unsigned short amount = 1
        );
        /*!
            requires
                - 0 <= symbol < alphabet_size
                - 0 < amount < 32768
            ensures
                - if (sufficient memory is available to complete this operation) then
                    - returns true
                    - if (get_total()+amount < 65536) then
                        - #get_count(symbol) == get_count(symbol) + amount
                    - else
                        - #get_count(symbol) == get_count(symbol)/2 + amount
                        - if (get_count(alphabet_size-1) == 1) then
                            - #get_count(alphabet_size-1) == 1
                        - else
                            - #get_count(alphabet_size-1) == get_count(alphabet_size-1)/2
                        - for all X where (X != symbol)&&(X != alphabet_size-1): 
                          #get_count(X) == get_count(X)/2
                - else
                    - returns false
        !*/

        unsigned long get_count (
            unsigned long symbol
        ) const;
        /*!
            requires
                - 0 <= symbol < alphabet_size
            ensures
                - returns the count for the specified symbol
        !*/

        unsigned long get_total (
        ) const;
        /*!
            ensures
                - returns the sum of get_count(X) for all valid values of X
                  (i.e. returns the sum of the counts for all the symbols)
        !*/

        unsigned long get_range (
            unsigned long symbol,
            unsigned long& low_count,
            unsigned long& high_count,
            unsigned long& total_count
        ) const;
        /*!
            requires
                - 0 <= symbol < alphabet_size
            ensures
                - returns get_count(symbol)
                - if (get_count(symbol) != 0) then
                    - #total_count == get_total()
                    - #low_count   == LOW_COUNT(symbol)
                    - #high_count  == HIGH_COUNT(symbol)
                    - #low_count  <  #high_count <= #total_count
        !*/

        void get_symbol (
            unsigned long target,
            unsigned long& symbol,            
            unsigned long& low_count,
            unsigned long& high_count
        ) const;
        /*!
            requires
                - 0 <= target < get_total()
            ensures
                - LOW_COUNT(#symbol) <= target < HIGH_COUNT(#symbol)
                - #low_count  == LOW_COUNT(#symbol)
                - #high_count == HIGH_COUNT(#symbol)
                - #low_count  <  #high_count <= get_total()
        !*/

        global_state_type& get_global_state (
        );
        /*!
            ensures
                - returns a reference to the global state used by *this
        !*/

        unsigned long get_memory_usage (
        ) const;
        /*!
            ensures
                - returns the number of bytes of memory allocated by all conditioning_class
                  objects that share the global state given by get_global_state()
        !*/

        static unsigned long get_alphabet_size (
        );
        /*!
            ensures
                - returns alphabet_size
        !*/

    private:

        // restricted functions
        conditioning_class(conditioning_class&);        // copy constructor
        conditioning_class& operator=(conditioning_class&);    // assignment operator

    };   

}

#endif // DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_

================================================
FILE: dlib/conditioning_class/conditioning_class_kernel_c.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONDITIONING_CLASS_KERNEl_C_ #define DLIB_CONDITIONING_CLASS_KERNEl_C_ #include "conditioning_class_kernel_abstract.h" #include "../algs.h" #include "../assert.h" #include namespace dlib { template < typename cc_base > class conditioning_class_kernel_c : public cc_base { const unsigned long alphabet_size; public: conditioning_class_kernel_c ( typename cc_base::global_state_type& global_state ) : cc_base(global_state),alphabet_size(cc_base::get_alphabet_size()) {} bool increment_count ( unsigned long symbol, unsigned short amount = 1 ); unsigned long get_count ( unsigned long symbol ) const; unsigned long get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const; void get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename cc_base > bool conditioning_class_kernel_c:: increment_count ( unsigned long symbol, unsigned short amount ) { // make sure requires clause is not broken DLIB_CASSERT(symbol < alphabet_size && 0 < amount && amount < 32768, "\tvoid conditioning_class::increment_count()" << "\n\tthe symbol must be in the range 0 to alphabet_size-1. 
and" << "\n\tamount must be in the range 1 to 32767" << "\n\talphabet_size: " << alphabet_size << "\n\tsymbol: " << symbol << "\n\tamount: " << amount << "\n\tthis: " << this ); // call the real function return cc_base::increment_count(symbol,amount); } // ---------------------------------------------------------------------------------------- template < typename cc_base > unsigned long conditioning_class_kernel_c:: get_count ( unsigned long symbol ) const { // make sure requires clause is not broken DLIB_CASSERT(symbol < alphabet_size, "\tvoid conditioning_class::get_count()" << "\n\tthe symbol must be in the range 0 to alphabet_size-1" << "\n\talphabet_size: " << alphabet_size << "\n\tsymbol: " << symbol << "\n\tthis: " << this ); // call the real function return cc_base::get_count(symbol); } // ---------------------------------------------------------------------------------------- template < typename cc_base > unsigned long conditioning_class_kernel_c:: get_range ( unsigned long symbol, unsigned long& low_count, unsigned long& high_count, unsigned long& total_count ) const { // make sure requires clause is not broken DLIB_CASSERT(symbol < alphabet_size, "\tvoid conditioning_class::get_range()" << "\n\tthe symbol must be in the range 0 to alphabet_size-1" << "\n\talphabet_size: " << alphabet_size << "\n\tsymbol: " << symbol << "\n\tthis: " << this ); // call the real function return cc_base::get_range(symbol,low_count,high_count,total_count); } // ---------------------------------------------------------------------------------------- template < typename cc_base > void conditioning_class_kernel_c:: get_symbol ( unsigned long target, unsigned long& symbol, unsigned long& low_count, unsigned long& high_count ) const { // make sure requires clause is not broken DLIB_CASSERT( target < this->get_total(), "\tvoid conditioning_class::get_symbol()" << "\n\tthe target must be in the range 0 to get_total()-1" << "\n\tget_total(): " << this->get_total() << "\n\ttarget: " 
<< target << "\n\tthis: " << this ); // call the real function cc_base::get_symbol(target,symbol,low_count,high_count); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONDITIONING_CLASS_KERNEl_C_ ================================================ FILE: dlib/conditioning_class.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONDITIONING_CLASs_ #define DLIB_CONDITIONING_CLASs_ #include "conditioning_class/conditioning_class_kernel_1.h" #include "conditioning_class/conditioning_class_kernel_2.h" #include "conditioning_class/conditioning_class_kernel_3.h" #include "conditioning_class/conditioning_class_kernel_4.h" #include "conditioning_class/conditioning_class_kernel_c.h" #include "memory_manager.h" namespace dlib { template < unsigned long alphabet_size > class conditioning_class { conditioning_class() {} typedef memory_manager::kernel_2b mm; public: //----------- kernels --------------- // kernel_1a typedef conditioning_class_kernel_1 kernel_1a; typedef conditioning_class_kernel_c kernel_1a_c; // kernel_2a typedef conditioning_class_kernel_2 kernel_2a; typedef conditioning_class_kernel_c kernel_2a_c; // kernel_3a typedef conditioning_class_kernel_3 kernel_3a; typedef conditioning_class_kernel_c kernel_3a_c; // -------- kernel_4 --------- // kernel_4a typedef conditioning_class_kernel_4 kernel_4a; typedef conditioning_class_kernel_c kernel_4a_c; // kernel_4b typedef conditioning_class_kernel_4 kernel_4b; typedef conditioning_class_kernel_c kernel_4b_c; // kernel_4c typedef conditioning_class_kernel_4 kernel_4c; typedef conditioning_class_kernel_c kernel_4c_c; // kernel_4d typedef conditioning_class_kernel_4 kernel_4d; typedef conditioning_class_kernel_c kernel_4d_c; }; } #endif // DLIB_CONDITIONING_CLASS_ ================================================ FILE: 
dlib/config.h ================================================ // If you are compiling dlib as a shared library and installing it somewhere on your system // then it is important that any programs that use dlib agree on the state of the // DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, // uncomment one of the following lines to force all DLIB_ASSERTs to either always on or // always off. If you don't define one of these two macros then DLIB_ASSERT will toggle // automatically depending on the state of certain other macros, which is not what you want // when creating a shared library. //#define ENABLE_ASSERTS // asserts always enabled //#define DLIB_DISABLE_ASSERTS // asserts always disabled //#define DLIB_ISO_CPP_ONLY //#define DLIB_NO_GUI_SUPPORT //#define DLIB_ENABLE_STACK_TRACE // You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, // and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. // #define DLIB_JPEG_SUPPORT // #define DLIB_PNG_SUPPORT // #define DLIB_GIF_SUPPORT // #define DLIB_USE_FFTW // #define DLIB_USE_BLAS // #define DLIB_USE_LAPACK // #define DLIB_USE_CUDA // Define this so the code in dlib/test_for_odr_violations.h can detect ODR violations // related to users doing bad things with config.h #define DLIB_NOT_CONFIGURED ================================================ FILE: dlib/config.h.in ================================================ // If you are compiling dlib as a shared library and installing it somewhere on your system // then it is important that any programs that use dlib agree on the state of the // DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, // uncomment one of the following lines to force all DLIB_ASSERTs to either always on or // always off. 
If you don't define one of these two macros then DLIB_ASSERT will toggle // automatically depending on the state of certain other macros, which is not what you want // when creating a shared library. #cmakedefine ENABLE_ASSERTS // asserts always enabled #cmakedefine DLIB_DISABLE_ASSERTS // asserts always disabled #cmakedefine DLIB_ISO_CPP_ONLY #cmakedefine DLIB_NO_GUI_SUPPORT #cmakedefine DLIB_ENABLE_STACK_TRACE #cmakedefine LAPACK_FORCE_UNDERSCORE #cmakedefine LAPACK_FORCE_NOUNDERSCORE // You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, // and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. #cmakedefine DLIB_JPEG_SUPPORT #cmakedefine DLIB_PNG_SUPPORT #cmakedefine DLIB_GIF_SUPPORT #cmakedefine DLIB_WEBP_SUPPORT #cmakedefine DLIB_JXL_SUPPORT #cmakedefine DLIB_USE_FFTW #cmakedefine DLIB_USE_BLAS #cmakedefine DLIB_USE_LAPACK #cmakedefine DLIB_USE_CUDA #cmakedefine DLIB_USE_MKL_FFT #cmakedefine DLIB_USE_FFMPEG // This variable allows dlib/test_for_odr_violations.h to catch people who mistakenly use // headers from one version of dlib with a compiled dlib binary from a different dlib version. #cmakedefine DLIB_CHECK_FOR_VERSION_MISMATCH @DLIB_CHECK_FOR_VERSION_MISMATCH@ ================================================ FILE: dlib/config_reader/config_reader_kernel_1.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONFIG_READER_KERNEl_1_ #define DLIB_CONFIG_READER_KERNEl_1_ #include "config_reader_kernel_abstract.h" #include #include #include #include #include "../algs.h" #include "../stl_checked/std_vector_c.h" #ifndef DLIB_ISO_CPP_ONLY #include "config_reader_thread_safe_1.h" #endif namespace dlib { template < typename map_string_string, typename map_string_void, typename tokenizer > class config_reader_kernel_1 { /*! 
REQUIREMENTS ON map_string_string is an implementation of map/map_kernel_abstract.h that maps std::string to std::string REQUIREMENTS ON map_string_void is an implementation of map/map_kernel_abstract.h that maps std::string to void* REQUIREMENTS ON tokenizer is an implementation of tokenizer/tokenizer_kernel_abstract.h CONVENTION key_table.is_in_domain(x) == is_key_defined(x) block_table.is_in_domain(x) == is_block_defined(x) key_table[x] == operator[](x) block_table[x] == (void*)&block(x) !*/ public: // These two typedefs are defined for backwards compatibility with older versions of dlib. typedef config_reader_kernel_1 kernel_1a; #ifndef DLIB_ISO_CPP_ONLY typedef config_reader_thread_safe_1< config_reader_kernel_1, map_string_void > thread_safe_1a; #endif // DLIB_ISO_CPP_ONLY config_reader_kernel_1(); class config_reader_error : public dlib::error { friend class config_reader_kernel_1; config_reader_error( unsigned long ln, bool r = false ) : dlib::error(ECONFIG_READER), line_number(ln), redefinition(r) { std::ostringstream sout; sout << "Error in config_reader while parsing at line number " << line_number << "."; if (redefinition) sout << "\nThe identifier on this line has already been defined in this scope."; const_cast(info) = sout.str(); } public: const unsigned long line_number; const bool redefinition; }; class file_not_found : public dlib::error { friend class config_reader_kernel_1; file_not_found( const std::string& file_name_ ) : dlib::error(ECONFIG_READER, "Error in config_reader, unable to open file " + file_name_), file_name(file_name_) {} ~file_not_found() noexcept {} public: const std::string file_name; }; class config_reader_access_error : public dlib::error { public: config_reader_access_error( const std::string& block_name_, const std::string& key_name_ ) : dlib::error(ECONFIG_READER), block_name(block_name_), key_name(key_name_) { std::ostringstream sout; sout << "Error in config_reader.\n"; if (block_name.size() > 0) sout << " A block with 
the name '" << block_name << "' was expected but not found."; else if (key_name.size() > 0) sout << " A key with the name '" << key_name << "' was expected but not found."; const_cast(info) = sout.str(); } ~config_reader_access_error() noexcept {} const std::string block_name; const std::string key_name; }; config_reader_kernel_1( const std::string& config_file ); config_reader_kernel_1( std::istream& in ); virtual ~config_reader_kernel_1( ); void clear ( ); void load_from ( std::istream& in ); void load_from ( const std::string& config_file ); bool is_key_defined ( const std::string& key ) const; bool is_block_defined ( const std::string& name ) const; typedef config_reader_kernel_1 this_type; const this_type& block ( const std::string& name ) const; const std::string& operator[] ( const std::string& key ) const; template < typename queue_of_strings > void get_keys ( queue_of_strings& keys ) const; template < typename alloc > void get_keys ( std::vector& keys ) const; template < typename alloc > void get_keys ( std_vector_c& keys ) const; template < typename queue_of_strings > void get_blocks ( queue_of_strings& blocks ) const; template < typename alloc > void get_blocks ( std::vector& blocks ) const; template < typename alloc > void get_blocks ( std_vector_c& blocks ) const; private: static void parse_config_file ( config_reader_kernel_1& cr, tokenizer& tok, unsigned long& line_number, const bool top_of_recursion = true ); /*! requires - line_number == 1 - cr == *this - top_of_recursion == true ensures - parses the data coming from tok and puts it into cr. 
throws - config_reader_error !*/ map_string_string key_table; map_string_void block_table; // restricted functions config_reader_kernel_1(config_reader_kernel_1&); config_reader_kernel_1& operator=(config_reader_kernel_1&); }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > config_reader_kernel_1:: config_reader_kernel_1( ) { } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > void config_reader_kernel_1:: clear( ) { // free all our blocks block_table.reset(); while (block_table.move_next()) { delete static_cast(block_table.element().value()); } block_table.clear(); key_table.clear(); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > void config_reader_kernel_1:: load_from( std::istream& in ) { clear(); tokenizer tok; tok.set_stream(in); tok.set_identifier_token( tok.lowercase_letters() + tok.uppercase_letters(), tok.lowercase_letters() + tok.uppercase_letters() + tok.numbers() + "_-." ); unsigned long line_number = 1; try { parse_config_file(*this,tok,line_number); } catch (...) 
{ clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > void config_reader_kernel_1:: load_from( const std::string& config_file ) { clear(); std::ifstream fin(config_file.c_str()); if (!fin) throw file_not_found(config_file); load_from(fin); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > config_reader_kernel_1:: config_reader_kernel_1( std::istream& in ) { load_from(in); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > config_reader_kernel_1:: config_reader_kernel_1( const std::string& config_file ) { load_from(config_file); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > void config_reader_kernel_1:: parse_config_file( config_reader_kernel_1& cr, tokenizer& tok, unsigned long& line_number, const bool top_of_recursion ) { int type; std::string token; bool in_comment = false; bool seen_identifier = false; std::string identifier; while (true) { tok.get_token(type,token); // ignore white space if (type == tokenizer::WHITE_SPACE) continue; // basically ignore end of lines if (type == tokenizer::END_OF_LINE) { ++line_number; in_comment = false; continue; } // we are in a comment still so ignore this if (in_comment) continue; // if this is the start of a comment if (type == tokenizer::CHAR && token[0] == '#') { in_comment = true; continue; } // if this is the case then we have just finished parsing a block so we should // quit this function if ( (type == tokenizer::CHAR && token[0] == '}' && !top_of_recursion) || (type == 
tokenizer::END_OF_FILE && top_of_recursion) ) { break; } if (seen_identifier) { seen_identifier = false; // the next character should be either a '=' or a '{' if (type != tokenizer::CHAR || (token[0] != '=' && token[0] != '{')) throw config_reader_error(line_number); if (token[0] == '=') { // we should parse the value out now // first discard any white space if (tok.peek_type() == tokenizer::WHITE_SPACE) tok.get_token(type,token); std::string value; type = tok.peek_type(); token = tok.peek_token(); while (true) { if (type == tokenizer::END_OF_FILE || type == tokenizer::END_OF_LINE) break; if (type == tokenizer::CHAR && token[0] == '\\') { tok.get_token(type,token); if (tok.peek_type() == tokenizer::CHAR && tok.peek_token()[0] == '#') { tok.get_token(type,token); value += '#'; } else if (tok.peek_type() == tokenizer::CHAR && tok.peek_token()[0] == '}') { tok.get_token(type,token); value += '}'; } else { value += '\\'; } } else if (type == tokenizer::CHAR && (token[0] == '#' || token[0] == '}')) { break; } else { value += token; tok.get_token(type,token); } type = tok.peek_type(); token = tok.peek_token(); } // while(true) // strip of any tailing white space from value std::string::size_type pos = value.find_last_not_of(" \t\r\n"); if (pos == std::string::npos) value.clear(); else value.erase(pos+1); // make sure this key isn't already in the key_table if (cr.key_table.is_in_domain(identifier)) throw config_reader_error(line_number,true); // add this key/value pair to the key_table cr.key_table.add(identifier,value); } else // when token[0] == '{' { // make sure this identifier isn't already in the block_table if (cr.block_table.is_in_domain(identifier)) throw config_reader_error(line_number,true); config_reader_kernel_1* new_cr = new config_reader_kernel_1; void* vtemp = new_cr; try { cr.block_table.add(identifier,vtemp); } catch (...) 
{ delete new_cr; throw; } // now parse this block parse_config_file(*new_cr,tok,line_number,false); } } else { // the next thing should be an identifier but if it isn't this is an error if (type != tokenizer::IDENTIFIER) throw config_reader_error(line_number); seen_identifier = true; identifier = token; } } // while (true) } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > config_reader_kernel_1:: ~config_reader_kernel_1( ) { clear(); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > bool config_reader_kernel_1:: is_key_defined ( const std::string& key ) const { return key_table.is_in_domain(key); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > bool config_reader_kernel_1:: is_block_defined ( const std::string& name ) const { return block_table.is_in_domain(name); } // ---------------------------------------------------------------------------------------- template < typename mss, typename msv, typename tokenizer > const config_reader_kernel_1& config_reader_kernel_1:: block ( const std::string& name ) const { if (is_block_defined(name) == false) { throw config_reader_access_error(name,""); } return *static_cast(block_table[name]); } // ---------------------------------------------------------------------------------------- template < typename map_string_string, typename map_string_void, typename tokenizer > const std::string& config_reader_kernel_1:: operator[] ( const std::string& key ) const { if (is_key_defined(key) == false) { throw config_reader_access_error("",key); } return key_table[key]; } // 
----------------------------------------------------------------------------------------

    template <
        typename map_string_string,
        typename map_string_void,
        typename tokenizer
        >
    template <
        typename queue_of_strings
        >
    void config_reader_kernel_1::
    get_keys (
        queue_of_strings& keys
    ) const
    {
        // Replaces the contents of keys with the name of every key in this block.
        // key_table is walked with reset()/move_next() from inside a const method,
        // so this relies on the map's enumeration state being usable on a const map.
        keys.clear();
        key_table.reset();
        std::string temp;
        while (key_table.move_next())
        {
            // NOTE(review): the key is copied into a local before enqueue(),
            // presumably because dlib queues swap the argument in rather than
            // copying it -- confirm against queue/queue_kernel_abstract.h.
            temp = key_table.element().key();
            keys.enqueue(temp);
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename map_string_string,
        typename map_string_void,
        typename tokenizer
        >
    template <
        typename alloc
        >
    void config_reader_kernel_1::
    get_keys (
        std::vector& keys
    ) const
    {
        // Same as the queue overload, but push_back() copies each key directly.
        keys.clear();
        key_table.reset();
        while (key_table.move_next())
        {
            keys.push_back(key_table.element().key());
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename map_string_string,
        typename map_string_void,
        typename tokenizer
        >
    template <
        typename alloc
        >
    void config_reader_kernel_1::
    get_keys (
        std_vector_c& keys
    ) const
    {
        // Same as the std::vector overload, for dlib's checked std_vector_c.
        keys.clear();
        key_table.reset();
        while (key_table.move_next())
        {
            keys.push_back(key_table.element().key());
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename map_string_string,
        typename map_string_void,
        typename tokenizer
        >
    template <
        typename queue_of_strings
        >
    void config_reader_kernel_1::
    get_blocks (
        queue_of_strings& blocks
    ) const
    {
        // Replaces the contents of blocks with the name of every sub block of this
        // block.  Same enumeration/enqueue pattern as the queue overload of get_keys().
        blocks.clear();
        block_table.reset();
        std::string temp;
        while (block_table.move_next())
        {
            temp = block_table.element().key();
            blocks.enqueue(temp);
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename map_string_string,
        typename map_string_void,
        typename tokenizer
        >
    template <
        typename alloc
        >
    void config_reader_kernel_1::
    get_blocks (
        std::vector& blocks
    ) const
    {
        // Same as the queue overload, but push_back() copies each name directly.
        blocks.clear();
        block_table.reset();
        while (block_table.move_next())
        {
            blocks.push_back(block_table.element().key());
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename map_string_string,
        typename map_string_void,
        typename tokenizer
        >
    template <
        typename alloc
        >
    void config_reader_kernel_1::
    get_blocks (
        std_vector_c& blocks
    ) const
    {
        // Same as the std::vector overload, for dlib's checked std_vector_c.
        blocks.clear();
        block_table.reset();
        while (block_table.move_next())
        {
            blocks.push_back(block_table.element().key());
        }
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_CONFIG_READER_KERNEl_1_

================================================
FILE: dlib/config_reader/config_reader_kernel_abstract.h
================================================
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CONFIG_READER_KERNEl_ABSTRACT_
#ifdef DLIB_CONFIG_READER_KERNEl_ABSTRACT_

#include
#include

namespace dlib
{

    class config_reader
    {

        /*!
            INITIAL VALUE
                - there aren't any keys defined for this object
                - there aren't any blocks defined for this object

            POINTERS AND REFERENCES TO INTERNAL DATA
                The destructor, clear(), and load_from() invalidate pointers
                and references to internal data.  All other functions are guaranteed
                to NOT invalidate pointers or references to internal data.

            WHAT THIS OBJECT REPRESENTS
                This object represents something which is intended to be used to read
                text configuration files that are defined by the following EBNF (with
                config_file as the starting symbol):

                config_file    = block;
                block          = { key_value_pair | sub_block };
                key_value_pair = key_name, "=", value;
                sub_block      = block_name, "{", block, "}";
                key_name       = identifier;
                block_name     = identifier;
                value          = matches any string of text that ends with a newline character, # or }.
                                 note that the trailing newline, # or } is not part of the value though.
                identifier     = Any string that matches the following regular expression:
                                 [a-zA-Z][a-zA-Z0-9_.-]*
                                 i.e. Any string that starts with a letter and then is continued
                                 with any number of letters, numbers, _ . or - characters.

                Whitespace and comments are ignored.  A comment is text that starts with # (but
                not an escaped # since the \ escapes the # so that you can have a # symbol in a
                value if you want) and ends in a new line.  You can also escape a } (e.g. "\}")
                if you want to have one in a value.

                Note that in a value the leading and trailing white spaces are stripped off but
                any white space inside the value is preserved.

                Also note that all key_names and block_names within a block syntax group must be
                unique but don't have to be globally unique.  I.e. different blocks can reuse names.

                EXAMPLE CONFIG FILES:

                    Example 1:
                        #comment.  This line is ignored because it starts with #

                        #here we have key1 which will have the value of "my value"
                        key1 = my value

                        another_key= another value  # this is another key called "another_key" with
                                                    # a value of "another value"

                        # this key's value is the empty string.  I.e. ""
                        key2=

                    Example 2:
                        #this example illustrates the use of blocks
                        some_key = blah blah

                        # now here is a block
                        our_block
                        {
                            # here we can define some keys and values that are local to this block.
                            a_key = something
                            foo = bar
                            some_key = more stuff # note that it is ok to name our key this even though
                                                  # there is a key called some_key above.  This is because
                                                  # we are doing so inside a different block
                        }

                        another_block { foo = bar2 } # this block has only one key and is all on a single line
        !*/

    public:

        // exception classes
        class config_reader_error : public dlib::error
        {
            /*!
                GENERAL
                    This exception is thrown if there is an error while parsing the
                    config file.  The type member of this exception will be set
                    to ECONFIG_READER.

                INTERPRETING THIS EXCEPTION
                    - line_number == the line number the parser was at when the
                      error occurred.
                    - if (redefinition) then
                        - The key or block name on line line_number has already
                          been defined in this scope which is an error.
                    - else
                        - Some other general syntax error was detected
            !*/
        public:
            const unsigned long line_number;
            const bool redefinition;
        };

        class file_not_found : public dlib::error
        {
            /*!
                GENERAL
                    This exception is thrown if the config file can't be opened for
                    some reason.  The type member of this exception will be set
                    to ECONFIG_READER.

                INTERPRETING THIS EXCEPTION
                    - file_name == the name of the config file which we failed to open
            !*/
        public:
            const std::string file_name;
        };

        class config_reader_access_error : public dlib::error
        {
            /*!
                GENERAL
                    This exception is thrown if you try to access a key or
                    block that doesn't exist inside a config reader.  The type
                    member of this exception will be set to ECONFIG_READER.
            !*/
        public:
            config_reader_access_error(
                const std::string& block_name_,
                const std::string& key_name_
            );
            /*!
                ensures
                    - #block_name == block_name_
                    - #key_name == key_name_
            !*/

            const std::string block_name;
            const std::string key_name;
        };

    // --------------------------

        config_reader(
        );
        /*!
            ensures
                - #*this is properly initialized
                - This object will not have any keys or blocks defined in it.
            throws
                - std::bad_alloc
                  (nothing is parsed here, so config_reader_error is not thrown)
        !*/

        config_reader(
            std::istream& in
        );
        /*!
            ensures
                - #*this is properly initialized
                - reads the config file to parse from the given input stream,
                  parses it and loads this object up with all the sub blocks and
                  key/value pairs it finds.
                - before the load is performed, the previous state of the config file
                  reader is erased.  So after the load the config file reader will contain
                  only information from the given config file.
                - This object will represent the top most block of the config file.
            throws
                - std::bad_alloc
                - config_reader_error
        !*/

        config_reader(
            const std::string& config_file
        );
        /*!
            ensures
                - #*this is properly initialized
                - parses the config file named by the config_file string.  Specifically,
                  parses it and loads this object up with all the sub blocks and
                  key/value pairs it finds in the file.
                - before the load is performed, the previous state of the config file
                  reader is erased.  So after the load the config file reader will contain
                  only information from the given config file.
                - This object will represent the top most block of the config file.
            throws
                - std::bad_alloc
                - config_reader_error
                - file_not_found
        !*/

        virtual ~config_reader(
        );
        /*!
            ensures
                - all memory associated with *this has been released
        !*/

        void clear(
        );
        /*!
            ensures
                - #*this has its initial value
            throws
                - std::bad_alloc
                    If this exception is thrown then *this is unusable
                    until clear() is called and succeeds
        !*/

        void load_from (
            std::istream& in
        );
        /*!
            ensures
                - reads the config file to parse from the given input stream,
                  parses it and loads this object up with all the sub blocks and
                  key/value pairs it finds.
                - before the load is performed, the previous state of the config file
                  reader is erased.  So after the load the config file reader will contain
                  only information from the given config file.
                - *this will represent the top most block of the config file contained
                  in the input stream in.
            throws
                - std::bad_alloc
                    If this exception is thrown then *this is unusable
                    until clear() is called and succeeds
                - config_reader_error
                    If this exception is thrown then this object will
                    revert to its initial value.
        !*/

        void load_from (
            const std::string& config_file
        );
        /*!
            ensures
                - parses the config file named by the config_file string.  Specifically,
                  parses it and loads this object up with all the sub blocks and
                  key/value pairs it finds in the file.
                - before the load is performed, the previous state of the config file
                  reader is erased.  So after the load the config file reader will contain
                  only information from the given config file.
                - This object will represent the top most block of the config file.
            throws
                - std::bad_alloc
                    If this exception is thrown then *this is unusable
                    until clear() is called and succeeds
                - config_reader_error
                    If this exception is thrown then this object will
                    revert to its initial value.
                - file_not_found
                    If this exception is thrown then this object will
                    revert to its initial value.
        !*/

        bool is_key_defined (
            const std::string& key_name
        ) const;
        /*!
            ensures
                - if (there is a key with the given name defined within this config_reader's block) then
                    - returns true
                - else
                    - returns false
        !*/

        bool is_block_defined (
            const std::string& block_name
        ) const;
        /*!
            ensures
                - if (there is a sub block with the given name defined within this config_reader's block) then
                    - returns true
                - else
                    - returns false
        !*/

        typedef config_reader this_type;
        const this_type& block (
            const std::string& block_name
        ) const;
        /*!
            ensures
                - if (is_block_defined(block_name) == true) then
                    - returns a const reference to the config_reader that represents the given named sub block
                - else
                    - throws config_reader_access_error
            throws
                - config_reader_access_error
                    if this exception is thrown then its block_name field will be set to the
                    given block_name string.
        !*/

        const std::string& operator[] (
            const std::string& key_name
        ) const;
        /*!
            ensures
                - if (is_key_defined(key_name) == true) then
                    - returns a const reference to the value string associated with the given key in
                      this config_reader's block.
                - else
                    - throws config_reader_access_error
            throws
                - config_reader_access_error
                    if this exception is thrown then its key_name field will be set to the
                    given key_name string.
        !*/

        template <
            typename queue_of_strings
            >
        void get_keys (
            queue_of_strings& keys
        ) const;
        /*!
            requires
                - queue_of_strings is an implementation of queue/queue_kernel_abstract.h
                  with T set to std::string, or a std::vector of std::string, or a
                  dlib::std_vector_c of std::string
            ensures
                - #keys == a collection containing all the keys defined in this config_reader's block.
                  (i.e. for all strings str in keys it is the case that is_key_defined(str) == true)
        !*/

        template <
            typename queue_of_strings
            >
        void get_blocks (
            queue_of_strings& blocks
        ) const;
        /*!
            requires
                - queue_of_strings is an implementation of queue/queue_kernel_abstract.h
                  with T set to std::string, or a std::vector of std::string, or a
                  dlib::std_vector_c of std::string
            ensures
                - #blocks == a collection containing the names of all the blocks defined in this
                  config_reader's block.
                  (i.e. for all strings str in blocks it is the case that is_block_defined(str) == true)
        !*/

    private:

        // restricted functions
        config_reader(config_reader&);        // copy constructor
        config_reader& operator=(config_reader&);    // assignment operator

    };

}

#endif // DLIB_CONFIG_READER_KERNEl_ABSTRACT_
================================================ FILE: dlib/config_reader/config_reader_thread_safe_1.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONFIG_READER_THREAD_SAFe_ #define DLIB_CONFIG_READER_THREAD_SAFe_ #include "config_reader_kernel_abstract.h" #include #include #include #include "../algs.h" #include "../interfaces/enumerable.h" #include "../threads.h" #include "config_reader_thread_safe_abstract.h" namespace dlib { template < typename config_reader_base, typename map_string_void > class config_reader_thread_safe_1 { /*!
CONVENTION - get_mutex() == *m - *cr == the config reader being extended - block_table[x] == (void*)&block(x) - block_table.size() == the number of blocks in *cr - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) - if (own_pointers) then - this object owns the m and cr pointers and should delete them when destructed !*/ public: config_reader_thread_safe_1 ( const config_reader_base* base, rmutex* m_ ); config_reader_thread_safe_1(); typedef typename config_reader_base::config_reader_error config_reader_error; typedef typename config_reader_base::config_reader_access_error config_reader_access_error; config_reader_thread_safe_1( std::istream& in ); config_reader_thread_safe_1( const std::string& config_file ); virtual ~config_reader_thread_safe_1( ); void clear ( ); void load_from ( std::istream& in ); void load_from ( const std::string& config_file ); bool is_key_defined ( const std::string& key ) const; bool is_block_defined ( const std::string& name ) const; typedef config_reader_thread_safe_1 this_type; const this_type& block ( const std::string& name ) const; const std::string& operator[] ( const std::string& key ) const; template < typename queue_of_strings > void get_keys ( queue_of_strings& keys ) const; template < typename queue_of_strings > void get_blocks ( queue_of_strings& blocks ) const; inline const rmutex& get_mutex ( ) const; private: void fill_block_table ( ); /*! 
ensures - block_table.size() == the number of blocks in cr - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) !*/ rmutex* m; config_reader_base* cr; map_string_void block_table; const bool own_pointers; // restricted functions config_reader_thread_safe_1(config_reader_thread_safe_1&); config_reader_thread_safe_1& operator=(config_reader_thread_safe_1&); }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > config_reader_thread_safe_1:: config_reader_thread_safe_1( const config_reader_base* base, rmutex* m_ ) : m(m_), cr(const_cast(base)), own_pointers(false) { fill_block_table(); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > config_reader_thread_safe_1:: config_reader_thread_safe_1( ) : m(0), cr(0), own_pointers(true) { try { m = new rmutex; cr = new config_reader_base; } catch (...) 
{ if (m) delete m; if (cr) delete cr; throw; } } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > void config_reader_thread_safe_1:: clear( ) { auto_mutex M(*m); cr->clear(); fill_block_table(); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > void config_reader_thread_safe_1:: load_from( std::istream& in ) { auto_mutex M(*m); cr->load_from(in); fill_block_table(); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > void config_reader_thread_safe_1:: load_from( const std::string& config_file ) { auto_mutex M(*m); cr->load_from(config_file); fill_block_table(); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > config_reader_thread_safe_1:: config_reader_thread_safe_1( std::istream& in ) : m(0), cr(0), own_pointers(true) { try { m = new rmutex; cr = new config_reader_base(in); fill_block_table(); } catch (...) { if (m) delete m; if (cr) delete cr; throw; } } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > config_reader_thread_safe_1:: config_reader_thread_safe_1( const std::string& config_file ) : m(0), cr(0), own_pointers(true) { try { m = new rmutex; cr = new config_reader_base(config_file); fill_block_table(); } catch (...) 
{ if (m) delete m; if (cr) delete cr; throw; } } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > config_reader_thread_safe_1:: ~config_reader_thread_safe_1( ) { if (own_pointers) { delete m; delete cr; } // clear out the block table block_table.reset(); while (block_table.move_next()) { delete static_cast(block_table.element().value()); } block_table.clear(); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > bool config_reader_thread_safe_1:: is_key_defined ( const std::string& key ) const { auto_mutex M(*m); return cr->is_key_defined(key); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > bool config_reader_thread_safe_1:: is_block_defined ( const std::string& name ) const { auto_mutex M(*m); return cr->is_block_defined(name); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > const config_reader_thread_safe_1& config_reader_thread_safe_1:: block ( const std::string& name ) const { auto_mutex M(*m); if (block_table.is_in_domain(name) == false) { throw config_reader_access_error(name,""); } return *static_cast(block_table[name]); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > const std::string& config_reader_thread_safe_1:: operator[] ( const std::string& key ) const { auto_mutex M(*m); return (*cr)[key]; } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > template < typename queue_of_strings > void 
config_reader_thread_safe_1:: get_keys ( queue_of_strings& keys ) const { auto_mutex M(*m); cr->get_keys(keys); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > template < typename queue_of_strings > void config_reader_thread_safe_1:: get_blocks ( queue_of_strings& blocks ) const { auto_mutex M(*m); cr->get_blocks(blocks); } // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > const rmutex& config_reader_thread_safe_1:: get_mutex ( ) const { return *m; } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member functions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename config_reader_base, typename map_string_void > void config_reader_thread_safe_1:: fill_block_table ( ) { // first empty out the block table block_table.reset(); while (block_table.move_next()) { delete static_cast(block_table.element().value()); } block_table.clear(); std::vector blocks; cr->get_blocks(blocks); // now fill the block table up to match what is in cr for (unsigned long i = 0; i < blocks.size(); ++i) { config_reader_thread_safe_1* block = new config_reader_thread_safe_1(&cr->block(blocks[i]),m); void* temp = block; block_table.add(blocks[i],temp); } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONFIG_READER_THREAD_SAFe_ ================================================ FILE: dlib/config_reader/config_reader_thread_safe_abstract.h ================================================ // Copyright (C) 2007 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ #ifdef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ #include #include #include "config_reader_kernel_abstract.h" #include "../threads/threads_kernel_abstract.h" namespace dlib { class config_reader_thread_safe { /*! WHAT THIS EXTENSION DOES FOR config_reader This object extends a normal config_reader by simply wrapping all its member functions inside mutex locks to make it safe to use in a threaded program. So this object provides an interface identical to the one defined in the config_reader/config_reader_kernel_abstract.h file except that the rmutex returned by get_mutex() is always locked when this object's member functions are called. !*/ public: const rmutex& get_mutex ( ) const; /*! ensures - returns the rmutex used to make this object thread safe. i.e. returns the rmutex that is locked when this object's functions are called. !*/ }; } #endif // DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ ================================================ FILE: dlib/config_reader.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_CONFIG_READEr_ #define DLIB_CONFIG_READEr_ #include "config_reader/config_reader_kernel_1.h" #include "map.h" #include "tokenizer.h" #include "cmd_line_parser/get_option.h" #include "algs.h" #include "is_kind.h" namespace dlib { typedef config_reader_kernel_1< map::kernel_1b, map::kernel_1b, tokenizer::kernel_1a > config_reader; template <> struct is_config_reader { const static bool value = true; }; #ifndef DLIB_ISO_CPP_ONLY typedef config_reader_thread_safe_1< config_reader, map::kernel_1b > config_reader_thread_safe; template <> struct is_config_reader { const static bool value = true; }; #endif // DLIB_ISO_CPP_ONLY } #endif // DLIB_CONFIG_READEr_ ================================================ FILE: dlib/console_progress_indicator.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ #define DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ #include #include #include #include namespace dlib { // ---------------------------------------------------------------------------------------- class console_progress_indicator { /*! WHAT THIS OBJECT REPRESENTS This object is a tool for reporting how long a task will take to complete. For example, consider the following bit of code: console_progress_indicator pbar(100) for (int i = 1; i <= 100; ++i) { pbar.print_status(i); long_running_operation(); } The above code will print a message to the console each iteration which shows the current progress and how much time is remaining until the loop terminates. !*/ public: inline explicit console_progress_indicator ( double target_value ); /*! ensures - #target() == target_value !*/ inline void reset ( double target_value ); /*! ensures - #target() == target_value - performs the equivalent of: *this = console_progress_indicator(target_value) (i.e. 
resets this object with a new target value) !*/ inline double target ( ) const; /*! ensures - This object attempts to measure how much time is left until we reach a certain targeted value. This function returns that targeted value. !*/ inline bool print_status ( double cur, bool always_print = false, std::ostream& out = std::clog ); /*! ensures - print_status() assumes it is called with values which are linearly approaching target(). It will display the current progress and attempt to predict how much time is remaining until cur becomes equal to target(). - prints a status message to out which indicates how much more time is left until cur is equal to target() - if (always_print) then - This function prints to the screen each time it is called. - else - This function throttles the printing so that at most 1 message is printed each second. Note that it won't print anything to the screen until about one second has elapsed. This means that the first call to print_status() never prints to the screen. - This function returns true if it prints to the screen and false otherwise. !*/ inline void finish ( std::ostream& out = std::cout ) const; /*! ensures - This object prints the completed progress and the elapsed time to out. It is meant to be called after the loop we are tracking the progress of. 
!*/ private: double target_val; std::chrono::time_point start_time; double first_val; double seen_first_val; std::chrono::time_point last_time; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // IMPLEMENTATION DETAILS // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- console_progress_indicator:: console_progress_indicator ( double target_value ) : target_val(target_value), start_time(std::chrono::steady_clock::now()), first_val(0), seen_first_val(false), last_time(std::chrono::steady_clock::now()) { } // ---------------------------------------------------------------------------------------- bool console_progress_indicator:: print_status ( double cur, bool always_print, std::ostream& out ) { const auto cur_time = std::chrono::steady_clock::now(); // if this is the first time print_status has been called // then collect some information and exit. We will print status // on the next call. 
if (!seen_first_val) { start_time = cur_time; last_time = cur_time; first_val = cur; seen_first_val = true; return false; } if ((cur_time - last_time) >= std::chrono::seconds(1) || always_print) { last_time = cur_time; const auto delta_t = cur_time - start_time; double delta_val = std::abs(cur - first_val); // don't do anything if cur is equal to first_val if (delta_val < std::numeric_limits::epsilon()) return false; const auto rem_time = delta_t / delta_val * std::abs(target_val - cur); const auto oldflags = out.flags(); out.setf(std::ios::fixed,std::ios::floatfield); std::streamsize ss; // adapt the precision based on whether the target val is an integer if (std::trunc(target_val) == target_val) ss = out.precision(0); else ss = out.precision(2); out << "Progress: " << cur << "/" << target_val; ss = out.precision(2); out << " (" << cur / target_val * 100. << "%). "; const auto hours = std::chrono::duration_cast(rem_time); const auto minutes = std::chrono::duration_cast(rem_time) - hours; const auto seconds = std::chrono::duration_cast(rem_time) - hours - minutes; out << "Time remaining: "; if (rem_time >= std::chrono::hours(1)) out << hours.count() << "h "; if (rem_time >= std::chrono::minutes(1)) out << minutes.count() << "min "; out << seconds.count() << "s. 
\r" << std::flush; // restore previous output flags and precision settings out.flags(oldflags); out.precision(ss); return true; } return false; } // ---------------------------------------------------------------------------------------- double console_progress_indicator:: target ( ) const { return target_val; } // ---------------------------------------------------------------------------------------- void console_progress_indicator:: reset ( double target_value ) { *this = console_progress_indicator(target_value); } // ---------------------------------------------------------------------------------------- void console_progress_indicator:: finish ( std::ostream& out ) const { const auto oldflags = out.flags(); out.setf(std::ios::fixed,std::ios::floatfield); std::streamsize ss; // adapt the precision based on whether the target val is an integer if (std::trunc(target_val) == target_val) ss = out.precision(0); else ss = out.precision(2); out << "Progress: " << target_val << "/" << target_val; out << " (100.00%). "; const auto delta_t = std::chrono::steady_clock::now() - start_time; const auto hours = std::chrono::duration_cast(delta_t); const auto minutes = std::chrono::duration_cast(delta_t) - hours; const auto seconds = std::chrono::duration_cast(delta_t) - hours - minutes; out << "Time elapsed: "; if (delta_t >= std::chrono::hours(1)) out << hours.count() << "h "; if (delta_t >= std::chrono::minutes(1)) out << minutes.count() << "min "; out << seconds.count() << "s. " << std::endl; // restore previous output flags and precision settings out.flags(oldflags); out.precision(ss); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ ================================================ FILE: dlib/constexpr_if.h ================================================ // Copyright (C) 2022 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_IF_CONSTEXPR_H #define DLIB_IF_CONSTEXPR_H #include "overloaded.h" #include "type_traits.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace detail { const auto _ = [](auto&& arg) -> decltype(auto) { return std::forward(arg); }; } // ---------------------------------------------------------------------------------------- template constexpr auto bools(std::integral_constant...) /*! ensures - returns a type list of compile time booleans. !*/ { return types_...>{}; } using true_t = types_; using false_t = types_; // ---------------------------------------------------------------------------------------- template < typename... T, typename... Cases > constexpr decltype(auto) switch_( types_ /*meta_obj*/, Cases&&... cases ) /*! requires - meta_obj combines a set of initial types. These are used as compile-time initial conditions. - cases is a set of overload-able conditional branches. - at least one of the cases is callable given meta_obj. - each case statement has signature auto(types_<>..., auto _) where _ is an identity function with identical behaviour to std::identity. This is used to make each generic lambda artificially dependent on the function body. This allows semantic analysis of the lambdas to be performed AFTER the correct lambda is chosen depending on meta_obj. This is the crucial bit that makes switch_() behave in a similar way to "if constexpr()" in C++17. Make sure to use _ on one of the objects in the lambdas. ensures - calls the correct conditional branch. - the correct conditional branch is selected at compile-time. - Note, each branch can return different types, and the return type of the switch_() function is that of the compile-time selected branch. 
Here is an example: template auto perform_correct_action(T& obj) { return switch_( types_{}, [&](types_, auto _) { return _(obj).set_something_specific_to_A_and_return_something(); }, [&](types_, auto _) { return _(obj).set_something_specific_to_B_and_return_something(); }, [&](auto...) { // Default case statement. Do something sensible. return false; } ); } Here is another example: template auto transfer_state(T& a, T& b) { return switch_( bools(std::is_move_constructible{}, std::is_copy_constructible{}), [&](true_t, auto, auto _) { // T is both move-constructible. Copy semantics can be anything a = std::move(_(b)); return move_tag{}; // Just for fun, we return different types in each branch. }, [&](auto, true_t, auto _) { // T is copy-constructible, Move semantics can be anything. Though in this case, // if it had been move-constructible, the first branch would have been selected. // So in this case, it is not move-constructible. a = _(b); return copy_tag{}; }, [&](auto...) { // Default case statement return dont_care_tag{}; } ); } !*/ { return overloaded(std::forward(cases)...)(types_{}..., detail::_); } // ---------------------------------------------------------------------------------------- } #endif //DLIB_IF_CONSTEXPR_H ================================================ FILE: dlib/control/approximate_linear_models.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ #define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ #include "approximate_linear_models_abstract.h" #include "../matrix.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > struct process_sample { typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; process_sample(){} process_sample( const state_type& s, const action_type& a, const state_type& n, const double& r ) : state(s), action(a), next_state(n), reward(r) {} state_type state; action_type action; state_type next_state; double reward; }; template < typename feature_extractor > void serialize (const process_sample& item, std::ostream& out) { serialize(item.state, out); serialize(item.action, out); serialize(item.next_state, out); serialize(item.reward, out); } template < typename feature_extractor > void deserialize (process_sample& item, std::istream& in) { deserialize(item.state, in); deserialize(item.action, in); deserialize(item.next_state, in); deserialize(item.reward, in); } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class policy { public: typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; policy ( ) { w.set_size(fe.num_features()); w = 0; } policy ( const matrix& weights_, const feature_extractor& fe_ ) : w(weights_), fe(fe_) {} action_type operator() ( const state_type& state ) const { return fe.find_best_action(state,w); } const feature_extractor& get_feature_extractor ( ) const { return fe; } const matrix& get_weights ( ) const { return w; } private: matrix w; feature_extractor fe; }; template < typename feature_extractor > inline void serialize(const 
policy& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.get_feature_extractor(), out); serialize(item.get_weights(), out); } template < typename feature_extractor > inline void deserialize(policy& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::policy object."); feature_extractor fe; matrix w; deserialize(fe, in); deserialize(w, in); item = policy(w,fe); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ ================================================ FILE: dlib/control/approximate_linear_models_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ #ifdef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ #include "../matrix.h" namespace dlib { // ---------------------------------------------------------------------------------------- struct example_feature_extractor { /*! WHAT THIS OBJECT REPRESENTS This object defines the interface a feature extractor must implement if it is to be used with the process_sample and policy objects defined at the bottom of this file. Moreover, it is meant to represent the core part of a model used in a reinforcement learning algorithm. In particular, this object models a Q(state,action) function where Q(state,action) == dot(w, PSI(state,action)) where PSI(state,action) is a feature vector and w is a parameter vector. Therefore, a feature extractor defines how the PSI(x,y) feature vector is calculated. It also defines the types used to represent the state and action objects. 
THREAD SAFETY Instances of this object are required to be threadsafe, that is, it should be safe for multiple threads to make concurrent calls to the member functions of this object. !*/ // The state and actions can be any types so long as you provide typedefs for them. typedef T state_type; typedef U action_type; // We can also say that the last element in the weight vector w must be 1. This // can be useful for including a prior into your model. const static bool force_last_weight_to_1 = false; example_feature_extractor( ); /*! ensures - this object is properly initialized. !*/ unsigned long num_features( ) const; /*! ensures - returns the dimensionality of the PSI() feature vector. !*/ action_type find_best_action ( const state_type& state, const matrix& w ) const; /*! ensures - returns the action A that maximizes Q(state,A) = dot(w,PSI(state,A)). That is, this function finds the best action to take in the given state when our model is parameterized by the given weight vector w. !*/ void get_features ( const state_type& state, const action_type& action, matrix& feats ) const; /*! ensures - #feats.size() == num_features() - #feats == PSI(state,action) !*/ }; // ---------------------------------------------------------------------------------------- template < typename feature_extractor > struct process_sample { /*! REQUIREMENTS ON feature_extractor feature_extractor should implement the example_feature_extractor interface defined at the top of this file. WHAT THIS OBJECT REPRESENTS This object holds a training sample for a reinforcement learning algorithm. In particular, it should be a sample from some process where the process was in state this->state, then took this->action action which resulted in receiving this->reward and ending up in the state this->next_state. 
!*/ typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; process_sample(){} process_sample( const state_type& s, const action_type& a, const state_type& n, const double& r ) : state(s), action(a), next_state(n), reward(r) {} state_type state; action_type action; state_type next_state; double reward; }; template < typename feature_extractor > void serialize (const process_sample& item, std::ostream& out); template < typename feature_extractor > void deserialize (process_sample& item, std::istream& in); /*! provides serialization support. !*/ // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class policy { /*! REQUIREMENTS ON feature_extractor feature_extractor should implement the example_feature_extractor interface defined at the top of this file. WHAT THIS OBJECT REPRESENTS This is a policy based on the supplied feature_extractor model. In particular, it maps from feature_extractor::state_type to the best action to take in that state. !*/ public: typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; policy ( ); /*! ensures - #get_feature_extractor() == feature_extractor() (i.e. it will have its default value) - #get_weights().size() == #get_feature_extractor().num_features() - #get_weights() == 0 !*/ policy ( const matrix& weights, const feature_extractor& fe ); /*! requires - fe.num_features() == weights.size() ensures - #get_feature_extractor() == fe - #get_weights() == weights !*/ action_type operator() ( const state_type& state ) const; /*! ensures - returns get_feature_extractor().find_best_action(state,w); !*/ const feature_extractor& get_feature_extractor ( ) const; /*! 
ensures - returns the feature extractor used by this object !*/ const matrix& get_weights ( ) const; /*! ensures - returns the parameter vector (w) associated with this object. The length of the vector is get_feature_extractor().num_features(). !*/ }; template < typename feature_extractor > void serialize(const policy& item, std::ostream& out); template < typename feature_extractor > void deserialize(policy& item, std::istream& in); /*! provides serialization support. !*/ // ---------------------------------------------------------------------------------------- #endif // DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ ================================================ FILE: dlib/control/lspi.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_LSPI_Hh_ #define DLIB_LSPI_Hh_ #include "lspi_abstract.h" #include "approximate_linear_models.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class lspi { public: typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; explicit lspi( const feature_extractor& fe_ ) : fe(fe_) { init(); } lspi( ) { init(); } double get_discount ( ) const { return discount; } void set_discount ( double value ) { // make sure requires clause is not broken DLIB_ASSERT(0 < value && value <= 1, "\t void lspi::set_discount(value)" << "\n\t invalid inputs were given to this function" << "\n\t value: " << value ); discount = value; } const feature_extractor& get_feature_extractor ( ) const { return fe; } void be_verbose ( ) { verbose = true; } void be_quiet ( ) { verbose = false; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void 
lspi::set_epsilon(eps_)" << "\n\t invalid inputs were given to this function" << "\n\t eps_: " << eps_ ); eps = eps_; } double get_epsilon ( ) const { return eps; } void set_lambda ( double lambda_ ) { // make sure requires clause is not broken DLIB_ASSERT(lambda_ >= 0, "\t void lspi::set_lambda(lambda_)" << "\n\t invalid inputs were given to this function" << "\n\t lambda_: " << lambda_ ); lambda = lambda_; } double get_lambda ( ) const { return lambda; } void set_max_iterations ( unsigned long max_iter ) { max_iterations = max_iter; } unsigned long get_max_iterations ( ) { return max_iterations; } template policy train ( const vector_type& samples ) const { // make sure requires clause is not broken DLIB_ASSERT(samples.size() > 0, "\t policy lspi::train(samples)" << "\n\t invalid inputs were given to this function" ); matrix w(fe.num_features()); w = 0; matrix prev_w, b, f1, f2; matrix A; double change; unsigned long iter = 0; do { A = identity_matrix(fe.num_features())*lambda; b = 0; for (unsigned long i = 0; i < samples.size(); ++i) { fe.get_features(samples[i].state, samples[i].action, f1); fe.get_features(samples[i].next_state, fe.find_best_action(samples[i].next_state,w), f2); A += f1*trans(f1 - discount*f2); b += f1*samples[i].reward; } prev_w = w; if (feature_extractor::force_last_weight_to_1) w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0)); else w = pinv(A)*b; change = length(w-prev_w); ++iter; if (verbose) std::cout << "iteration: " << iter << "\tchange: " << change << std::endl; } while(change > eps && iter < max_iterations); return policy(w,fe); } private: void init() { lambda = 0.01; discount = 0.8; eps = 0.01; verbose = false; max_iterations = 100; } double lambda; double discount; double eps; bool verbose; unsigned long max_iterations; feature_extractor fe; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_LSPI_Hh_ 
================================================ FILE: dlib/control/lspi_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_LSPI_ABSTRACT_Hh_ #ifdef DLIB_LSPI_ABSTRACT_Hh_ #include "approximate_linear_models_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class lspi { /*! REQUIREMENTS ON feature_extractor feature_extractor should implement the example_feature_extractor interface defined at the top of dlib/control/approximate_linear_models_abstract.h WHAT THIS OBJECT REPRESENTS This object is an implementation of the reinforcement learning algorithm described in the following paper: Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy iteration." The Journal of Machine Learning Research 4 (2003): 1107-1149. This means that it takes a bunch of training data in the form of process_samples and outputs a policy that hopefully performs well when run on the process that generated those samples. !*/ public: typedef feature_extractor feature_extractor_type; typedef typename feature_extractor::state_type state_type; typedef typename feature_extractor::action_type action_type; explicit lspi( const feature_extractor& fe_ ); /*! ensures - #get_feature_extractor() == fe_ - #get_lambda() == 0.01 - #get_discount == 0.8 - #get_epsilon() == 0.01 - is not verbose - #get_max_iterations() == 100 !*/ lspi( ); /*! ensures - #get_feature_extractor() == feature_extractor() (i.e. it will have its default value) - #get_lambda() == 0.01 - #get_discount == 0.8 - #get_epsilon() == 0.01 - is not verbose - #get_max_iterations() == 100 !*/ double get_discount ( ) const; /*! ensures - returns the discount applied to the sum of rewards in the Bellman equation. !*/ void set_discount ( double value ); /*! 
requires - 0 < value <= 1 ensures - #get_discount() == value !*/ const feature_extractor& get_feature_extractor ( ) const; /*! ensures - returns the feature extractor used by this object !*/ void be_verbose ( ); /*! ensures - This object will print status messages to standard out so that a user can observe the progress of the algorithm. !*/ void be_quiet ( ); /*! ensures - this object will not print anything to standard out !*/ void set_epsilon ( double eps ); /*! requires - eps > 0 ensures - #get_epsilon() == eps !*/ double get_epsilon ( ) const; /*! ensures - returns the error epsilon that determines when training should stop. Smaller values may result in a more accurate solution but take longer to train. !*/ void set_lambda ( double lambda_ ); /*! requires - lambda >= 0 ensures - #get_lambda() == lambda !*/ double get_lambda ( ) const; /*! ensures - returns the regularization parameter. It is the parameter that determines the trade off between trying to fit the training data exactly or allowing more errors but hopefully improving the generalization ability of the resulting function. Smaller values encourage exact fitting while larger values of lambda may encourage better generalization. !*/ void set_max_iterations ( unsigned long max_iter ); /*! ensures - #get_max_iterations() == max_iter !*/ unsigned long get_max_iterations ( ); /*! ensures - returns the maximum number of iterations the SVM optimizer is allowed to run before it is required to stop and return a result. !*/ template < typename vector_type > policy train ( const vector_type& samples ) const; /*! requires - samples.size() > 0 - samples is something with an interface that looks like std::vector>. That is, it should be some kind of array of process_sample objects. ensures - Trains a policy based on the given data and returns the results. The idea is to find a policy that will obtain the largest possible reward when run on the process that generated the samples. 
In particular, if the returned policy is P then: - P(S) == the best action to take when in state S. - if (feature_extractor::force_last_weight_to_1) then - The last element of P.get_weights() is 1. !*/ }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_LSPI_ABSTRACT_Hh_ ================================================ FILE: dlib/control/mpc.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_MPC_Hh_ #define DLIB_MPC_Hh_ #include "mpc_abstract.h" #include "../matrix.h" #include "../algs.h" namespace dlib { template < long S_, long I_, unsigned long horizon_ > class mpc { public: const static long S = S_; const static long I = I_; const static unsigned long horizon = horizon_; mpc( ) { A = 0; B = 0; C = 0; Q = 0; R = 0; lower = 0; upper = 0; max_iterations = 0; eps = 0.01; for (unsigned long i = 0; i < horizon; ++i) { target[i].set_size(A.nr()); target[i] = 0; controls[i].set_size(B.nc()); controls[i] = 0; } lambda = 0; } mpc ( const matrix& A_, const matrix& B_, const matrix& C_, const matrix& Q_, const matrix& R_, const matrix& lower_, const matrix& upper_ ) : A(A_), B(B_), C(C_), Q(Q_), R(R_), lower(lower_), upper(upper_) { // make sure requires clause is not broken DLIB_ASSERT(A.nr() > 0 && B.nc() > 0, "\t mpc::mpc()" << "\n\t invalid inputs were given to this function" << "\n\t A.nr(): " << A.nr() << "\n\t B.nc(): " << B.nc() ); DLIB_ASSERT(A.nr() == A.nc() && A.nr() == B.nr() && A.nr() == C.nr() && A.nr() == Q.nr(), "\t mpc::mpc()" << "\n\t invalid inputs were given to this function" << "\n\t A.nr(): " << A.nr() << "\n\t A.nc(): " << A.nc() << "\n\t B.nr(): " << B.nr() << "\n\t C.nr(): " << C.nr() << "\n\t Q.nr(): " << Q.nr() ); DLIB_ASSERT( B.nc() == R.nr() && B.nc() == lower.nr() && B.nc() == upper.nr() , "\t mpc::mpc()" << "\n\t invalid inputs were given to this 
function" << "\n\t B.nr(): " << B.nr() << "\n\t B.nc(): " << B.nc() << "\n\t lower.nr(): " << lower.nr() << "\n\t upper.nr(): " << upper.nr() ); DLIB_ASSERT(min(Q) >= 0 && min(R) > 0 && min(upper-lower) >= 0, "\t mpc::mpc()" << "\n\t invalid inputs were given to this function" << "\n\t min(Q): " << min(Q) << "\n\t min(R): " << min(R) << "\n\t min(upper-lower): " << min(upper-lower) ); max_iterations = 10000; eps = 0.01; for (unsigned long i = 0; i < horizon; ++i) { target[i].set_size(A.nr()); target[i] = 0; controls[i].set_size(B.nc()); controls[i] = 0; } // Bound the maximum eigenvalue of the hessian by computing the trace of the // hessian matrix. lambda = sum(R)*horizon; matrix temp = diagm(Q); for (unsigned long c = 0; c < horizon; ++c) { lambda += trace(trans(B)*temp*B); Q_diag[horizon-c-1] = diag(trans(B)*temp*B); temp = trans(A)*temp*A + diagm(Q); } } const matrix& get_A ( ) const { return A; } const matrix& get_B ( ) const { return B; } const matrix& get_C ( ) const { return C; } const matrix& get_Q ( ) const { return Q; } const matrix& get_R ( ) const { return R; } const matrix& get_lower_constraints ( ) const { return lower; } const matrix& get_upper_constraints ( ) const { return upper; } void set_target ( const matrix& val, const unsigned long time ) { DLIB_ASSERT(time < horizon, "\t void mpc::set_target(eps_)" << "\n\t invalid inputs were given to this function" << "\n\t time: " << time << "\n\t horizon: " << horizon ); target[time] = val; } void set_target ( const matrix& val ) { for (unsigned long i = 0; i < horizon; ++i) target[i] = val; } void set_last_target ( const matrix& val ) { set_target(val, horizon-1); } const matrix& get_target ( const unsigned long time ) const { // make sure requires clause is not broken DLIB_ASSERT(time < horizon, "\t matrix mpc::get_target(eps_)" << "\n\t invalid inputs were given to this function" << "\n\t time: " << time << "\n\t horizon: " << horizon ); return target[time]; } double get_target_error_threshold ( ) 
const { return target_error_threshold; } void set_target_error_threshold ( const double thresh ) { target_error_threshold = thresh; } unsigned long get_max_iterations ( ) const { return max_iterations; } void set_max_iterations ( unsigned long max_iter ) { max_iterations = max_iter; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void mpc::set_epsilon(eps_)" << "\n\t invalid inputs were given to this function" << "\n\t eps_: " << eps_ ); eps = eps_; } double get_epsilon ( ) const { return eps; } matrix operator() ( const matrix& current_state ) { // make sure requires clause is not broken DLIB_ASSERT(min(R) > 0 && A.nr() == current_state.size(), "\t matrix mpc::operator(current_state)" << "\n\t invalid inputs were given to this function" << "\n\t min(R): " << min(R) << "\n\t A.nr(): " << A.nr() << "\n\t current_state.size(): " << current_state.size() ); // Shift the inputs over by one time step so we can use them to warm start the // optimizer. for (unsigned long i = 1; i < horizon; ++i) controls[i-1] = controls[i]; solve_linear_mpc(current_state); for (unsigned long i = 1; i < horizon; ++i) target[i-1] = target[i]; return controls[0]; } private: // These temporary variables here just to avoid reallocating them on each call to // operator(). 
matrix M[horizon]; matrix MM[horizon]; matrix df[horizon]; matrix v[horizon]; matrix v_old[horizon]; void solve_linear_mpc ( const matrix& initial_state ) { // make it so MM == trans(K)*Q*(M-target) M[0] = A*initial_state + C; for (unsigned long i = 1; i < horizon; ++i) M[i] = A*M[i-1] + C; double min_error_seen = std::numeric_limits::infinity(); for (unsigned long i = 0; i < horizon; ++i) { M[i] = diagm(Q)*(M[i]-target[i]); if (target_error_threshold >= 0) { const double current_error = dot(M[i]-target[i], M[i]); min_error_seen = std::min(current_error, min_error_seen); // Once our trajectory gets us within target_error_threshold of the target at any time // then we essentially stop caring about what happens at times after that. This // gives us a "just hit the target, I don't care what happens after the hit" model. if (min_error_seen < target_error_threshold) { // Make it so all future errors now appear to be 0. E.g. it is as if target[i] // was equal to the state the current control sequence generates at time i. M[i] = 0; } } } for (long i = (long)horizon-2; i >= 0; --i) M[i] += trans(A)*M[i+1]; for (unsigned long i = 0; i < horizon; ++i) MM[i] = trans(B)*M[i]; unsigned long iter = 0; for (; iter < max_iterations; ++iter) { // compute current gradient and put it into df. // df == H*controls + MM; M[0] = B*controls[0]; for (unsigned long i = 1; i < horizon; ++i) M[i] = A*M[i-1] + B*controls[i]; for (unsigned long i = 0; i < horizon; ++i) M[i] = diagm(Q)*M[i]; for (long i = (long)horizon-2; i >= 0; --i) M[i] += trans(A)*M[i+1]; for (unsigned long i = 0; i < horizon; ++i) df[i] = MM[i] + trans(B)*M[i] + diagm(R)*controls[i]; // Check the stopping condition, which is the magnitude of the largest element // of the gradient. double max_df = 0; unsigned long max_t = 0; long max_v = 0; for (unsigned long i = 0; i < horizon; ++i) { for (long j = 0; j < controls[i].size(); ++j) { // if this variable isn't an active constraint then we care about it's // derivative. 
if (!((controls[i](j) <= lower(j) && df[i](j) > 0) || (controls[i](j) >= upper(j) && df[i](j) < 0))) { if (std::abs(df[i](j)) > max_df) { max_df = std::abs(df[i](j)); max_t = i; max_v = j; } } } } if (max_df < eps) break; // We will start out by doing a little bit of coordinate descent because it // allows us to optimize individual variables exactly. Since we are warm // starting each iteration with a really good solution this helps speed // things up a lot. const unsigned long smo_iters = 50; if (iter < smo_iters) { if (Q_diag[max_t](max_v) == 0) continue; // Take the optimal step but just for one variable. controls[max_t](max_v) = -(df[max_t](max_v)-Q_diag[max_t](max_v)*controls[max_t](max_v))/Q_diag[max_t](max_v); controls[max_t](max_v) = put_in_range(lower(max_v), upper(max_v), controls[max_t](max_v)); // If this is the last SMO iteration then don't forget to initialize v // for the gradient steps. if (iter+1 == smo_iters) { for (unsigned long i = 0; i < horizon; ++i) v[i] = controls[i]; } } else { // Take a projected gradient step. for (unsigned long i = 0; i < horizon; ++i) { v_old[i] = v[i]; v[i] = dlib::clamp(controls[i] - 1.0/lambda * df[i], lower, upper); controls[i] = dlib::clamp(v[i] + (std::sqrt(lambda)-1)/(std::sqrt(lambda)+1)*(v[i]-v_old[i]), lower, upper); } } } } unsigned long max_iterations; double eps; double target_error_threshold = -1; matrix A; matrix B; matrix C; matrix Q; matrix R; matrix lower; matrix upper; matrix target[horizon]; double lambda; // abound on the largest eigenvalue of the hessian matrix. matrix Q_diag[horizon]; matrix controls[horizon]; }; } #endif // DLIB_MPC_Hh_ ================================================ FILE: dlib/control/mpc_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#undef DLIB_MPC_ABSTRACT_Hh_ #ifdef DLIB_MPC_ABSTRACT_Hh_ #include "../matrix.h" namespace dlib { template < long S_, long I_, unsigned long horizon_ > class mpc { /*! REQUIREMENTS ON horizon_ horizon_ > 0 REQUIREMENTS ON S_ S_ >= 0 REQUIREMENTS ON I_ I_ >= 0 WHAT THIS OBJECT REPRESENTS This object implements a linear model predictive controller. To explain what that means, suppose you have some process you want to control and the process dynamics are described by the linear equation: x_{i+1} = A*x_i + B*u_i + C That is, the next state the system goes into is a linear function of its current state (x_i) and the current control (u_i) plus some constant bias or disturbance. A model predictive controller can find the control (u) you should apply to drive the state (x) to some reference value, or alternatively to make the state track some reference time-varying sequence. It does this by simulating the process for horizon_ time steps and selecting the control that leads to the best performance over the next horizon_ steps. To be precise, each time you ask this object for a control, it solves the following quadratic program: min sum_i trans(x_i-target_i)*Q*(x_i-target_i) + trans(u_i)*R*u_i x_i,u_i such that: x_0 == current_state x_{i+1} == A*x_i + B*u_i + C lower <= u_i <= upper 0 <= i < horizon_ and reports u_0 as the control you should take given that you are currently in current_state. Q and R are user supplied matrices that define how we penalize variations away from the target state as well as how much we want to avoid generating large control signals. Finally, the algorithm we use to solve this quadratic program is based largely on the method described in: A Fast Gradient method for embedded linear predictive control (2011) by Markus Kogel and Rolf Findeisen !*/ public: const static long S = S_; const static long I = I_; const static unsigned long horizon = horizon_; mpc( ); /*! 
ensures - #get_max_iterations() == 0 - The A,B,C,Q,R,lower, and upper parameter matrices are filled with zeros. Therefore, to use this object you must initialize it via the constructor that supplies these parameters. - #get_target_error_threshold() == -1 !*/ mpc ( const matrix& A, const matrix& B, const matrix& C, const matrix& Q, const matrix& R, const matrix& lower, const matrix& upper ); /*! requires - A.nr() > 0 - B.nc() > 0 - A.nr() == A.nc() == B.nr() == C.nr() == Q.nr() - B.nc() == R.nr() == lower.nr() == upper.nr() - min(Q) >= 0 - min(R) > 0 - min(upper-lower) >= 0 ensures - #get_A() == A - #get_B() == B - #get_C() == C - #get_Q() == Q - #get_R() == R - #get_lower_constraints() == lower - #get_upper_constraints() == upper - for all valid i: - get_target(i) == a vector of all zeros - get_target(i).size() == A.nr() - #get_max_iterations() == 10000 - #get_epsilon() == 0.01 - #get_target_error_threshold() == -1 !*/ const matrix& get_A ( ) const; /*! ensures - returns the A matrix from the quadratic program defined above. !*/ const matrix& get_B ( ) const; /*! ensures - returns the B matrix from the quadratic program defined above. !*/ const matrix& get_C ( ) const; /*! ensures - returns the C matrix from the quadratic program defined above. !*/ const matrix& get_Q ( ) const; /*! ensures - returns the diagonal of the Q matrix from the quadratic program defined above. !*/ const matrix& get_R ( ) const; /*! ensures - returns the diagonal of the R matrix from the quadratic program defined above. !*/ const matrix& get_lower_constraints ( ) const; /*! ensures - returns the lower matrix from the quadratic program defined above. All controls generated by this object will have values no less than this lower bound. That is, any control u will satisfy min(u-lower) >= 0. !*/ const matrix& get_upper_constraints ( ) const; /*! ensures - returns the upper matrix from the quadratic program defined above. 
All controls generated by this object will have values no larger than this upper bound. That is, any control u will satisfy min(upper-u) >= 0. !*/ const matrix& get_target ( const unsigned long time ) const; /*! requires - time < horizon ensures - This object will try to find the control sequence that results in the process obtaining get_target(time) state at the indicated time. Note that the next time instant after "right now" is time 0. !*/ void set_target ( const matrix& val, const unsigned long time ); /*! requires - time < horizon ensures - #get_target(time) == val !*/ void set_target ( const matrix& val ); /*! ensures - for all valid t: - #get_target(t) == val !*/ void set_last_target ( const matrix& val ); /*! ensures - performs: set_target(val, horizon-1) !*/ double get_target_error_threshold ( ) const; /*! ensures - The target error terms in the objective function with values less than get_target_error_threshold() are ignored. That is, the trans(x_i-target_i)*Q*(x_i-target_i) terms with values less than this are dropped from the objective function. Therefore, setting get_target_error_threshold() to a value >= 0 allows you to encode a control law that says "find me the controls that make the target error less than or equal to this at some point, but I don't care what happens at times after that." !*/ void set_target_error_threshold ( const double thresh ); /*! ensures - #target_error_threshold() == thresh !*/ unsigned long get_max_iterations ( ) const; /*! ensures - When operator() is called it solves an optimization problem to get_epsilon() precision to determine the next control action. In particular, we run the optimizer until the magnitude of each element of the gradient vector is less than get_epsilon() or until get_max_iterations() solver iterations have been executed. !*/ void set_max_iterations ( unsigned long max_iter ); /*! ensures - #get_max_iterations() == max_iter !*/ void set_epsilon ( double eps ); /*! 
requires - eps > 0 ensures - #get_epsilon() == eps !*/ double get_epsilon ( ) const; /*! ensures - When operator() is called it solves an optimization problem to get_epsilon() precision to determine the next control action. In particular, we run the optimizer until the magnitude of each element of the gradient vector is less than get_epsilon() or until get_max_iterations() solver iterations have been executed. This means that smaller epsilon values will give more accurate outputs but may take longer to compute. !*/ matrix operator() ( const matrix& current_state ); /*! requires - min(R) > 0 - A.nr() == current_state.size() ensures - Solves the model predictive control problem defined by the arguments to this object's constructor, assuming that the starting state is given by current_state. Then we return the control that should be taken in the current state that best optimizes the quadratic objective function defined above. - We also shift over the target states so that you only need to update the last one (if you are using non-zero target states) via a call to set_last_target()). In particular, for all valid t, it will be the case that: - #get_target(t) == get_target(t+1) - #get_target(horizon-1) == get_target(horizon-1) !*/ }; } #endif // DLIB_MPC_ABSTRACT_Hh_ ================================================ FILE: dlib/control.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CONTRoL_ #define DLIB_CONTRoL_ #include "control/lspi.h" #include "control/mpc.h" #endif // DLIB_CONTRoL_ ================================================ FILE: dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_1.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_CPP_PRETTY_PRINTER_KERNEl_1_ #define DLIB_CPP_PRETTY_PRINTER_KERNEl_1_ #include #include #include #include "cpp_pretty_printer_kernel_abstract.h" #include "../algs.h" namespace dlib { template < typename stack, typename tok > class cpp_pretty_printer_kernel_1 { /*! REQUIREMENTS ON stack must be an implementation of stack/stack_kernel_abstract.h and stack::type == unsigned long REQUIREMENTS ON tok must be an implementation of tokenizer/tokenizer_kernel_abstract.h INFO This implementation applies a color scheme, turns include directives such as #include "file.h" into links to file.h.html, and it also puts HTML anchor points on function and class declarations. !*/ public: cpp_pretty_printer_kernel_1 ( ); virtual ~cpp_pretty_printer_kernel_1 ( ); void print ( std::istream& in, std::ostream& out, const std::string& title ) const; void print_and_number ( std::istream& in, std::ostream& out, const std::string& title ) const; private: const std::string htmlify ( const std::string& str ) const; /*! ensures - str == str but with any '<' replaced with '<', any '>' replaced with '>', and any '&' replaced with '&' !*/ // data members mutable tok t; void number ( std::istream& in, std::ostream& out ) const; /*! 
ensures - prints in to out and adds line numbers !*/ // restricted functions cpp_pretty_printer_kernel_1(const cpp_pretty_printer_kernel_1&); // copy constructor cpp_pretty_printer_kernel_1& operator=(const cpp_pretty_printer_kernel_1&); // assignment operator }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > cpp_pretty_printer_kernel_1:: cpp_pretty_printer_kernel_1 ( ) { } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > cpp_pretty_printer_kernel_1:: ~cpp_pretty_printer_kernel_1 ( ) { } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > void cpp_pretty_printer_kernel_1:: print ( std::istream& in, std::ostream& out, const std::string& title ) const { if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print"); t.set_stream(in); out << "" << title << ""; if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print"); } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > void cpp_pretty_printer_kernel_1:: print_and_number ( std::istream& in, std::ostream& out, const std::string& title ) const { std::ostringstream sout; print(in,sout,title); std::istringstream sin(sout.str()); number(sin,out); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // 
private member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > void cpp_pretty_printer_kernel_1:: number ( std::istream& in, std::ostream& out ) const { if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::number"); std::string space = "   "; std::ios::int_type ch; unsigned long count = 1; while ((ch=in.get()) != EOF) { if (ch != '\n') { out << (char)ch; } else { out << "\n" << count << " " + space; ++count; if (count == 10) space = "  "; if (count == 100) space = " "; if (count == 1000) space = ""; } } if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::number"); } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > const std::string cpp_pretty_printer_kernel_1:: htmlify ( const std::string& str ) const { std::string::size_type i; std::string temp; for (i = 0; i < str.size(); ++i) { if (str[i] == '<') temp += "<"; else if (str[i] == '>') temp += ">"; else if (str[i] == '&') temp += "&"; else temp += str[i]; } return temp; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_1_ ================================================ FILE: dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_2.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CPP_PRETTY_PRINTER_KERNEl_2_ #define DLIB_CPP_PRETTY_PRINTER_KERNEl_2_ #include #include #include #include "cpp_pretty_printer_kernel_abstract.h" #include "../algs.h" namespace dlib { template < typename stack, typename tok > class cpp_pretty_printer_kernel_2 { /*! 
REQUIREMENTS ON stack must be an implementation of stack/stack_kernel_abstract.h and stack::type == unsigned long REQUIREMENTS ON tok must be an implementation of tokenizer/tokenizer_kernel_abstract.h INFO This implementation applies a black and white color scheme suitable for printing on a black and white printer. It also places the document title prominently at the top of the pretty printed source file. !*/ public: cpp_pretty_printer_kernel_2 ( ); virtual ~cpp_pretty_printer_kernel_2 ( ); void print ( std::istream& in, std::ostream& out, const std::string& title ) const; void print_and_number ( std::istream& in, std::ostream& out, const std::string& title ) const; private: // data members mutable tok t; const std::string htmlify ( const std::string& str ) const; /*! ensures - str == str but with any '<' replaced with '<', any '>' replaced with '>', and any '&' replaced with '&' !*/ void number ( std::istream& in, std::ostream& out ) const; /*! ensures - prints in to out and adds line numbers !*/ // restricted functions cpp_pretty_printer_kernel_2(const cpp_pretty_printer_kernel_2&); // copy constructor cpp_pretty_printer_kernel_2& operator=(const cpp_pretty_printer_kernel_2&); // assignment operator }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > cpp_pretty_printer_kernel_2:: cpp_pretty_printer_kernel_2 ( ) { } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > cpp_pretty_printer_kernel_2:: ~cpp_pretty_printer_kernel_2 ( ) { } // 
---------------------------------------------------------------------------------------- template < typename stack, typename tok > void cpp_pretty_printer_kernel_2:: print ( std::istream& in, std::ostream& out, const std::string& title ) const { if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print"); t.set_stream(in); out << "" << "" << title << "" << "

" << title << "

\n"
            << "\n";
        if (!out)
            throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print");

        unsigned long scope = 0; // counts the number of new scopes we have entered 
                        // since we were at a scope where functions can be declared

        bool recently_seen_class_keyword = false;
            // true if we have seen the keywords class or struct and
            // we have not seen any identifiers or { characters

        bool recently_seen_include = false;
            // true if we have seen the #include keyword and have not seen double
            // quoted text or >

        bool recently_seen_new_scope = false;  
            // true if we have seen the keywords class, namespace, or struct and
            // we have not seen the characters {, ), or ; since then

        bool recently_seen_paren = false;
            // true if we have seen a ) and we have only seen white_space or comments since

        bool in_initialization_list = false;
            // true if we have seen a ) followed by any white space or comments and then
            // followed by a : (in scope==0 with recently_seen_preprocessor==false) and we 
            // have not yet seen the character { or ;

        bool recently_seen_preprocessor = false;
            // true if we have seen the #pragma or #if or #define or #elif keyword and 
            // have not seen an identifier.


        bool recently_seen_extern = false;
            // true if we have seen the extern keyword and haven't yet seen a 
            // { or ; character.

        unsigned long paren_count = 0; 
            // this is the number of ( we have seen minus the number of ) we have
            // seen.
            

        int type;
        stack scopes; // a stack to hold old scopes
        std::string token, temp;
        t.get_token(type,token);
        while (type != tok::END_OF_FILE)
        {
            switch (type)
            {
            case tok::IDENTIFIER: // ------------------------------------------
                if ( recently_seen_class_keyword)
                {
                    // this might be a class name so check if there is a 
                    // ; or identifier or * or & coming up.
                    type = t.peek_type();
                    temp.clear();
                    if (type == tok::WHITE_SPACE)
                    {
                        t.get_token(type,temp);
                        if (temp.find_first_of("\n\r") != std::string::npos)
                            recently_seen_preprocessor = false;
                    }
                    if (t.peek_token() != ";" && t.peek_type() != tok::IDENTIFIER &&
                        t.peek_token() != "*" && t.peek_token() != "&")
                    {
                        // this is the name of a class or struct in a class or
                        // struct declaration.
                        out << "" << token << "" << temp;
                    }
                    else
                    {
                        out << token << temp;
                    }
                }
                else if ( !in_initialization_list &&
                     !recently_seen_preprocessor &&
                     scope == 0 &&
                     paren_count == 0)
                {
                    // this might be a function name so check if there is a 
                    // ( coming up.
                    type = t.peek_type();
                    temp.clear();
                    if (type == tok::WHITE_SPACE)
                    {
                        t.get_token(type,temp);
                        type = t.peek_type();
                    }
                    if (type == tok::OTHER && t.peek_token() == "(")
                    {
                        // this is a function definition or prototype
                        out << "" << token << "" << temp;
                    }
                    else
                    {
                        out << token << temp;
                    }
                }
                else
                {
                    out << token;
                }
                


                recently_seen_class_keyword = false;
                recently_seen_paren = false;
                break;

            case tok::KEYWORD: // ---------------------------------------------
                if (scope == 0 && token == "operator")
                {
                    // Doing this is sort of weird since operator is really a keyword
                    // but I just like how this looks.
                    out << "" << token << "";
                }
                // this isn't a keyword if it is something like #include 
                else if (!recently_seen_include) 
                {
                    // This is a normal keyword
                    out << "" << token << "";
                }
                else
                {
                    out << token;
                }

                if (token == "#include") 
                {
                    recently_seen_include = true;
                }
                else if (token == "class")
                {
                    recently_seen_new_scope = true;
                    recently_seen_class_keyword = true;
                }
                else if (token == "namespace")
                {
                    recently_seen_new_scope = true;
                }
                else if (token == "struct")
                {
                    recently_seen_new_scope = true;
                    recently_seen_class_keyword = true;
                }
                else if (token == "#pragma" || token == "#define" || token == "#elif" || token == "#if")
                {
                    recently_seen_preprocessor = true;
                }
                else if (token == "extern")
                {
                    recently_seen_extern = true;
                }
                recently_seen_paren = false;
                break;

            case tok::COMMENT: // ---------------------------------------------
                {
                    out << "" << htmlify(token) << "";
                }
                break;

            case tok::SINGLE_QUOTED_TEXT: // ----------------------------------
                {
                    out << htmlify(token);
                    recently_seen_paren = false;
                }
                break;

            case tok::WHITE_SPACE: // -----------------------------------------
                {
                    out << token;
                    if (token.find_first_of("\n\r") != std::string::npos)
                        recently_seen_preprocessor = false;
                }
                break;

            case tok::DOUBLE_QUOTED_TEXT: // ----------------------------------
                {                    
                    out << htmlify(token);
                    recently_seen_paren = false;
                    recently_seen_include = false;
                }
                break;

            case tok::NUMBER:
            case tok::OTHER: // -----------------------------------------------               
                switch (token[0])
                {
                case '{':
                    out << "{";  
                    // if we are entering a new scope
                    if (recently_seen_new_scope || recently_seen_extern)
                    {
                        recently_seen_new_scope = false;
                        scopes.push(scope);
                        scope = 0;
                    }
                    else
                    {
                        ++scope;
                    }
                    in_initialization_list = false;
                    recently_seen_paren = false;
                    recently_seen_class_keyword = false;
                    recently_seen_extern = false;
                    break;
                case '}':
                    out << "}";
                    if (scope > 0)
                    {
                        --scope;
                    }
                    else if (scopes.size())
                    {
                        scopes.pop(scope);
                    }
                    recently_seen_paren = false;
                    break;

                case ':':
                    out << ':';
                    if (recently_seen_paren && scope == 0 &&
                        recently_seen_preprocessor == false)
                    {
                        in_initialization_list = true;
                    }
                    recently_seen_paren = false;
                    break;

                case ';': 
                    out << ';';
                    recently_seen_new_scope = false;
                    recently_seen_paren = false;
                    recently_seen_extern = false;
                    break;

                case ')':
                    out << ')';
                    recently_seen_paren = true;
                    recently_seen_new_scope = false;
                    --paren_count;
                    break;

                case '(':
                    out << '(';
                    recently_seen_paren = false;
                    ++paren_count;
                    break;

                case '>':
                    recently_seen_include = false;
                    out << ">";
                    recently_seen_paren = true;
                    break;

                case '<':
                    out << "<";
                    recently_seen_paren = true;
                    break;

                case '&':
                    out << "&";
                    recently_seen_paren = true;
                    break;

                default:
                    out << token;
                    recently_seen_paren = false;
                    if (token == ">")
                        recently_seen_include = false;
                    break;

                } // switch (token[0])
                break;

            } // switch (type)

            t.get_token(type,token);
        } // while (type != tok::END_OF_FILE)


        out << "
"; if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::print"); } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > void cpp_pretty_printer_kernel_2:: print_and_number ( std::istream& in, std::ostream& out, const std::string& title ) const { std::ostringstream sout; print(in,sout,title); std::istringstream sin(sout.str()); number(sin,out); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // private member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > void cpp_pretty_printer_kernel_2:: number ( std::istream& in, std::ostream& out ) const { if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::number"); std::string space = "   "; std::ios::int_type ch; unsigned long count = 1; while ((ch=in.get()) != EOF) { if (ch != '\n') { out << (char)ch; } else { out << "\n" << count << " " + space; ++count; if (count == 10) space = "  "; if (count == 100) space = " "; if (count == 1000) space = ""; } } if (!out) throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_2::number"); } // ---------------------------------------------------------------------------------------- template < typename stack, typename tok > const std::string cpp_pretty_printer_kernel_2:: htmlify ( const std::string& str ) const { std::string::size_type i; std::string temp; for (i = 0; i < str.size(); ++i) { if (str[i] == '<') temp += "<"; else if (str[i] == '>') temp += ">"; else if (str[i] == '&') temp += "&"; else temp += str[i]; } return temp; } // 
---------------------------------------------------------------------------------------- } #endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_2_ ================================================ FILE: dlib/cpp_pretty_printer/cpp_pretty_printer_kernel_abstract.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_ #ifdef DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_ #include #include namespace dlib { class cpp_pretty_printer { /*! INITIAL VALUE This object does not have any state associated with it. WHAT THIS OBJECT REPRESENTS This object represents an HTML pretty printer for C++ source code. !*/ public: cpp_pretty_printer ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc !*/ virtual ~cpp_pretty_printer ( ); /*! ensures - any resources associated with *this have been released !*/ void print ( std::istream& in, std::ostream& out, const std::string& title ) const; /*! ensures - treats data from in as C++ source code and pretty prints it in HTML and writes it to out. - The title of the HTML document written to out will be title throws - std::ios_base::failure If there was a problem writing to out then this exception will be thrown. - any other exception This exception may be thrown if there is any other problem. !*/ void print_and_number ( std::istream& in, std::ostream& out, const std::string& title ) const; /*! ensures - treats data from in as C++ source code and pretty prints it in HTML with line numbers and writes it to out. - The title of the HTML document written to out will be title throws - std::ios_base::failure If there was a problem writing to out then this exception will be thrown. - any other exception This exception may be thrown if there is any other problem. 
!*/ private: // restricted functions cpp_pretty_printer(const cpp_pretty_printer&); // copy constructor cpp_pretty_printer& operator=(const cpp_pretty_printer&); // assignment operator }; } #endif // DLIB_CPP_PRETTY_PRINTER_KERNEl_ABSTRACT_ ================================================ FILE: dlib/cpp_pretty_printer.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CPP_PRETTY_PRINTEr_ #define DLIB_CPP_PRETTY_PRINTEr_ #include "cpp_pretty_printer/cpp_pretty_printer_kernel_1.h" #include "cpp_pretty_printer/cpp_pretty_printer_kernel_2.h" #include "cpp_tokenizer.h" #include "stack.h" namespace dlib { class cpp_pretty_printer { cpp_pretty_printer() {} typedef stack::kernel_1a stack; typedef cpp_tokenizer::kernel_1a tok; public: //----------- kernels --------------- // kernel_1a typedef cpp_pretty_printer_kernel_1 kernel_1a; // kernel_2a typedef cpp_pretty_printer_kernel_2 kernel_2a; }; } #endif // DLIB_CPP_PRETTY_PRINTEr_ ================================================ FILE: dlib/cpp_tokenizer/cpp_tokenizer_kernel_1.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CPP_TOKENIZER_KERNEl_1_ #define DLIB_CPP_TOKENIZER_KERNEl_1_ #include #include #include "cpp_tokenizer_kernel_abstract.h" #include "../algs.h" namespace dlib { namespace cpp_tok_kernel_1_helper { struct token_text_pair { std::string token; int type=0; }; } template < typename tok, typename queue, typename set > class cpp_tokenizer_kernel_1 { /*! 
REQUIREMENTS ON tok tok must be an implementation of tokenizer/tokenizer_kernel_abstract.h REQUIREMENTS ON queue queue must be an implementation of queue/queue_kernel_abstract.h and must have T==cpp_tok_kernel_1_helper::token_text_pair REQUIREMENTS ON set set must be an implemention of set/set_kernel_abstract.h or hash_set/hash_set_kernel_abstract.h and must have T==std::string. INITIAL VALUE - keywords == a set of all the C++ keywords - tokenizer.stream_is_set() == false - buffer.size() == 0 - tokenizer.get_identifier_head() == "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() - tokenizer.get_identifier_body() == "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() + tokenizer.numbers() - have_peeked == false CONVENTION - tokenizer.stream_is_set() == stream_is_set() - tokenizer.get_stream() == get_stream() - keywords == a set of all the C++ keywords - tokenizer.get_identifier_head() == "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() - tokenizer.get_identifier_body() == "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() + tokenizer.numbers() - buffer == a queue of tokens. This is where we put tokens we gathered early due to looking ahead. - if (have_peeked) then - next_token == the next token to be returned from get_token() - next_type == the type of token in peek_token !*/ typedef cpp_tok_kernel_1_helper::token_text_pair token_text_pair; public: enum { END_OF_FILE, KEYWORD, COMMENT, SINGLE_QUOTED_TEXT, DOUBLE_QUOTED_TEXT, IDENTIFIER, OTHER, NUMBER, WHITE_SPACE }; cpp_tokenizer_kernel_1 ( ); virtual ~cpp_tokenizer_kernel_1 ( ); void clear( ); void set_stream ( std::istream& in ); bool stream_is_set ( ) const; std::istream& get_stream ( ) const; void get_token ( int& type, std::string& token ); int peek_type ( ) const; const std::string& peek_token ( ) const; void swap ( cpp_tokenizer_kernel_1& item ); private: void buffer_token( int type, const std::string& token ) /*! 
ensures - stores the token and its type into buffer !*/ { token_text_pair temp; temp.token = token; temp.type = type; buffer.enqueue(temp); } void buffer_token( int type, char token ) /*! ensures - stores the token and its type into buffer !*/ { token_text_pair temp; temp.token = token; temp.type = type; buffer.enqueue(temp); } // restricted functions cpp_tokenizer_kernel_1(const cpp_tokenizer_kernel_1&); // copy constructor cpp_tokenizer_kernel_1& operator=(const cpp_tokenizer_kernel_1&); // assignment operator // data members set keywords; queue buffer; tok tokenizer; mutable std::string next_token; mutable int next_type; mutable bool have_peeked; }; template < typename tok, typename queue, typename set > inline void swap ( cpp_tokenizer_kernel_1& a, cpp_tokenizer_kernel_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > cpp_tokenizer_kernel_1:: cpp_tokenizer_kernel_1( ) : have_peeked(false) { // add C++ keywords to keywords std::string temp; temp = "#include"; keywords.add(temp); temp = "__asm"; keywords.add(temp); temp = "_asm"; keywords.add(temp); temp = "if"; keywords.add(temp); temp = "int"; keywords.add(temp); temp = "else"; keywords.add(temp); temp = "template"; keywords.add(temp); temp = "void"; keywords.add(temp); temp = "false"; keywords.add(temp); temp = "class"; keywords.add(temp); temp = "public"; keywords.add(temp); temp = "while"; keywords.add(temp); temp = "bool"; keywords.add(temp); temp = "new"; keywords.add(temp); temp = "delete"; keywords.add(temp); temp = "true"; keywords.add(temp); temp = "typedef"; keywords.add(temp); temp 
= "const"; keywords.add(temp); temp = "virtual"; keywords.add(temp); temp = "inline"; keywords.add(temp); temp = "for"; keywords.add(temp); temp = "break"; keywords.add(temp); temp = "struct"; keywords.add(temp); temp = "float"; keywords.add(temp); temp = "case"; keywords.add(temp); temp = "enum"; keywords.add(temp); temp = "this"; keywords.add(temp); temp = "typeid"; keywords.add(temp); temp = "double"; keywords.add(temp); temp = "char"; keywords.add(temp); temp = "typename"; keywords.add(temp); temp = "signed"; keywords.add(temp); temp = "friend"; keywords.add(temp); temp = "wint_t"; keywords.add(temp); temp = "default"; keywords.add(temp); temp = "asm"; keywords.add(temp); temp = "reinterpret_cast"; keywords.add(temp); temp = "#define"; keywords.add(temp); temp = "do"; keywords.add(temp); temp = "continue"; keywords.add(temp); temp = "auto"; keywords.add(temp); temp = "unsigned"; keywords.add(temp); temp = "size_t"; keywords.add(temp); temp = "#undef"; keywords.add(temp); temp = "#pragma"; keywords.add(temp); temp = "namespace"; keywords.add(temp); temp = "private"; keywords.add(temp); temp = "#endif"; keywords.add(temp); temp = "catch"; keywords.add(temp); temp = "#else"; keywords.add(temp); temp = "register"; keywords.add(temp); temp = "volatile"; keywords.add(temp); temp = "const_cast"; keywords.add(temp); temp = "#end"; keywords.add(temp); temp = "mutable"; keywords.add(temp); temp = "static_cast"; keywords.add(temp); temp = "wchar_t"; keywords.add(temp); temp = "#if"; keywords.add(temp); temp = "protected"; keywords.add(temp); temp = "throw"; keywords.add(temp); temp = "using"; keywords.add(temp); temp = "dynamic_cast"; keywords.add(temp); temp = "#ifdef"; keywords.add(temp); temp = "return"; keywords.add(temp); temp = "short"; keywords.add(temp); temp = "#error"; keywords.add(temp); temp = "#line"; keywords.add(temp); temp = "explicit"; keywords.add(temp); temp = "union"; keywords.add(temp); temp = "#ifndef"; keywords.add(temp); temp = "try"; 
keywords.add(temp); temp = "sizeof"; keywords.add(temp); temp = "goto"; keywords.add(temp); temp = "long"; keywords.add(temp); temp = "#elif"; keywords.add(temp); temp = "static"; keywords.add(temp); temp = "operator"; keywords.add(temp); temp = "switch"; keywords.add(temp); temp = "extern"; keywords.add(temp); // set the tokenizer's IDENTIFIER token for C++ identifiers tokenizer.set_identifier_token( "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters(), "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() + tokenizer.numbers() ); } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > cpp_tokenizer_kernel_1:: ~cpp_tokenizer_kernel_1 ( ) { } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > void cpp_tokenizer_kernel_1:: clear( ) { tokenizer.clear(); buffer.clear(); have_peeked = false; // set the tokenizer's IDENTIFIER token for C++ identifiers tokenizer.set_identifier_token( "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters(), "$_" + tokenizer.lowercase_letters() + tokenizer.uppercase_letters() + tokenizer.numbers() ); } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > void cpp_tokenizer_kernel_1:: set_stream ( std::istream& in ) { tokenizer.set_stream(in); buffer.clear(); have_peeked = false; } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > bool cpp_tokenizer_kernel_1:: stream_is_set ( ) const { return tokenizer.stream_is_set(); } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > std::istream& cpp_tokenizer_kernel_1:: get_stream ( ) 
const { return tokenizer.get_stream(); } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > void cpp_tokenizer_kernel_1:: get_token ( int& type, std::string& token ) { if (!have_peeked) { if (buffer.size() > 0) { // just return what is in the buffer token_text_pair temp; buffer.dequeue(temp); type = temp.type; token = temp.token; return; } tokenizer.get_token(type,token); switch (type) { case tok::END_OF_FILE: { type = END_OF_FILE; } break; case tok::END_OF_LINE: case tok::WHITE_SPACE: { type = tokenizer.peek_type(); if (type == tok::END_OF_LINE || type == tok::WHITE_SPACE) { std::string temp; do { tokenizer.get_token(type,temp); token += temp; type = tokenizer.peek_type(); }while (type == tok::END_OF_LINE || type == tok::WHITE_SPACE); } type = WHITE_SPACE; } break; case tok::NUMBER: { // this could be a hex number such as 0xa33. we should check for this. if (tokenizer.peek_type() == tok::IDENTIFIER && token == "0" && (tokenizer.peek_token()[0] == 'x' || tokenizer.peek_token()[0] == 'X')) { // this is a hex number so accumulate all the numbers and identifiers that follow // because they have to be part of the number std::string temp; tokenizer.get_token(type,temp); token = "0" + temp; // get the rest of the hex number while (tokenizer.peek_type() == tok::IDENTIFIER || tokenizer.peek_type() == tok::NUMBER ) { tokenizer.get_token(type,temp); token += temp; } } // or this could be a floating point value or something with an 'e' or 'E' in it. 
else if ((tokenizer.peek_type() == tok::CHAR && tokenizer.peek_token()[0] == '.') || (tokenizer.peek_type() == tok::IDENTIFIER && std::tolower(tokenizer.peek_token()[0]) == 'e')) { std::string temp; tokenizer.get_token(type,temp); token += temp; // now get the rest of the floating point value while (tokenizer.peek_type() == tok::IDENTIFIER || tokenizer.peek_type() == tok::NUMBER ) { tokenizer.get_token(type,temp); token += temp; } } type = NUMBER; } break; case tok::IDENTIFIER: { if (keywords.is_member(token)) { type = KEYWORD; } else { type = IDENTIFIER; } } break; case tok::CHAR: type = OTHER; switch (token[0]) { case '#': { // this might be a preprocessor keyword so we should check the // next token if (tokenizer.peek_type() == tok::IDENTIFIER && keywords.is_member('#'+tokenizer.peek_token())) { tokenizer.get_token(type,token); token = '#' + token; type = KEYWORD; } else { token = '#'; type = OTHER; } } break; case '"': { std::string temp; tokenizer.get_token(type,token); while (type != tok::END_OF_FILE) { // if this is the end of the quoted string if (type == tok::CHAR && token[0] == '"' && (temp.size() == 0 || temp[temp.size()-1] != '\\' || (temp.size() > 1 && temp[temp.size()-2] == '\\') )) { buffer_token(DOUBLE_QUOTED_TEXT,temp); buffer_token(OTHER,"\""); break; } else { temp += token; } tokenizer.get_token(type,token); } type = OTHER; token = '"'; } break; case '\'': { std::string temp; tokenizer.get_token(type,token); if (type == tok::CHAR && token[0] == '\\') { temp += '\\'; tokenizer.get_token(type,token); } temp += token; buffer_token(SINGLE_QUOTED_TEXT,temp); // The next character should be a ' so take it out and put it in // the buffer. 
tokenizer.get_token(type,token); buffer_token(OTHER,token); type = OTHER; token = '\''; } break; case '/': { // look ahead to see if this is the start of a comment if (tokenizer.peek_type() == tok::CHAR) { if (tokenizer.peek_token()[0] == '/') { tokenizer.get_token(type,token); // this is the start of a line comment token = "//"; std::string temp; tokenizer.get_token(type,temp); while (type != tok::END_OF_FILE) { // if this is the end of the comment if (type == tok::END_OF_LINE && token[token.size()-1] != '\\' ) { token += '\n'; break; } else { token += temp; } tokenizer.get_token(type,temp); } type = COMMENT; } else if (tokenizer.peek_token()[0] == '*') { tokenizer.get_token(type,token); // this is the start of a block comment token = "/*"; std::string temp; tokenizer.get_token(type,temp); while (type != tok::END_OF_FILE) { // if this is the end of the comment if (type == tok::CHAR && temp[0] == '/' && token[token.size()-1] == '*') { token += '/'; break; } else { token += temp; } tokenizer.get_token(type,temp); } type = COMMENT; } } } break; default: break; } // switch (token[0]) } // switch (type) } else { // if we get this far it means we have peeked so we should // return the peek data. 
type = next_type; token = next_token; have_peeked = false; } } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > int cpp_tokenizer_kernel_1:: peek_type ( ) const { const_cast*>(this)->get_token(next_type,next_token); have_peeked = true; return next_type; } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > const std::string& cpp_tokenizer_kernel_1:: peek_token ( ) const { const_cast*>(this)->get_token(next_type,next_token); have_peeked = true; return next_token; } // ---------------------------------------------------------------------------------------- template < typename tok, typename queue, typename set > void cpp_tokenizer_kernel_1:: swap ( cpp_tokenizer_kernel_1& item ) { tokenizer.swap(item.tokenizer); buffer.swap(item.buffer); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CPP_TOKENIZER_KERNEl_1_ ================================================ FILE: dlib/cpp_tokenizer/cpp_tokenizer_kernel_abstract.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_ #ifdef DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_ #include #include namespace dlib { class cpp_tokenizer { /*! INITIAL VALUE stream_is_set() == false WHAT THIS OBJECT REPRESENTS This object represents a simple tokenizer for C++ source code. BUFFERING This object is allowed to buffer data from the input stream. Thus if you clear it or switch streams (via calling set_stream()) any buffered data will be lost. TOKENS When picking out tokens the cpp_tokenizer will always extract the longest token it can. 
For example, if faced with the string "AAA" it will consider the three As to be a single IDENTIFIER token not three smaller IDENTIFIER tokens. Also note that no characters in the input stream are discarded. They will all be returned in the text of some token. Additionally, each character will never be returned more than once. This means that if you concatenated all returned tokens it would exactly reproduce the contents of the input stream. The tokens are defined as follows: END_OF_FILE This token represents the end of file. It doesn't have any actual characters associated with it. KEYWORD This token matches a C++ keyword. (This includes the preprocessor directives). COMMENT This token matches a C++ comment. SINGLE_QUOTED_TEXT This token matches the text of any single quoted literal. For example, 'a' would be a match and the text of this token would be the single character a. DOUBLE_QUOTED_TEXT This token matches the text of any double quoted string. For example, "C++" would be a match and the text of this token would be the three character string C++. WHITE_SPACE This is a multi character token. It is defined as a sequence of one or more spaces, carrage returns, newlines, and tabs. I.e. It is composed of characters from the following string " \r\n\t". IDENTIFIER This token matches any C++ identifier that isn't matched by any of the above tokens. (A C++ identifier being a string matching the regular expression [_$a-zA-Z][_$a-zA-Z0-9]*). NUMBER This token matches any C++ numerical constant. OTHER This matches anything that isn't part of one of the above tokens. It is always a single character. !*/ public: enum { END_OF_FILE, KEYWORD, COMMENT, SINGLE_QUOTED_TEXT, DOUBLE_QUOTED_TEXT, IDENTIFIER, OTHER, NUMBER, WHITE_SPACE }; cpp_tokenizer ( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc !*/ virtual ~cpp_tokenizer ( ); /*! ensures - any resources associated with *this have been released !*/ void clear( ); /*! 
ensures - #*this has its initial value throws - std::bad_alloc If this exception is thrown then #*this is unusable until clear() is called and succeeds. !*/ void set_stream ( std::istream& in ); /*! ensures - #*this will read data from in and tokenize it - #stream_is_set() == true - #get_stream() == in !*/ bool stream_is_set ( ) const; /*! ensures - returns true if a stream has been associated with *this by calling set_stream() !*/ std::istream& get_stream ( ) const; /*! requires - stream_is_set() == true ensures - returns a reference to the istream object that *this is reading from. !*/ void get_token ( int& type, std::string& token ); /*! requires - stream_is_set() == true ensures - #token == the next token from the input stream get_stream() - #type == the type of the token in #token throws - bad_alloc If this exception is thrown then the call to this function will have no effect on *this but the values of #type and #token will be undefined. Additionally, some characters may have been read from the stream get_stream() and lost. !*/ int peek_type ( ) const; /*! requires - stream_is_set() == true ensures - returns the type of the token that will be returned from the next call to get_token() throws - bad_alloc If this exception is thrown then the call to this function will have no effect on *this. However, some characters may have been read from the stream get_stream() and lost. !*/ const std::string& peek_token ( ) const; /*! requires - stream_is_set() == true ensures - returns the text of the token that will be returned from the next call to get_token() throws - bad_alloc If this exception is thrown then the call to this function will have no effect on *this. However, some characters may have been read from the stream get_stream() and lost. !*/ void swap ( cpp_tokenizer& item ); /*! 
ensures - swaps *this and item !*/ private: // restricted functions cpp_tokenizer(const cpp_tokenizer&); // copy constructor cpp_tokenizer& operator=(const cpp_tokenizer&); // assignment operator }; inline void swap ( cpp_tokenizer& a, cpp_tokenizer& b ) { a.swap(b); } /*! provides a global swap function !*/ } #endif // DLIB_CPP_TOKENIZER_KERNEl_ABSTRACT_ ================================================ FILE: dlib/cpp_tokenizer/cpp_tokenizer_kernel_c.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CPP_TOKENIZER_KERNEl_C_ #define DLIB_CPP_TOKENIZER_KERNEl_C_ #include "cpp_tokenizer_kernel_abstract.h" #include "../assert.h" #include #include namespace dlib { template < typename tokenizer > class cpp_tokenizer_kernel_c : public tokenizer { public: std::istream& get_stream ( ) const; void get_token ( int& type, std::string& token ); int peek_type ( ) const; const std::string& peek_token ( ) const; }; template < typename tokenizer > inline void swap ( cpp_tokenizer_kernel_c& a, cpp_tokenizer_kernel_c& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename tokenizer > std::istream& cpp_tokenizer_kernel_c:: get_stream ( ) const { // make sure requires clause is not broken DLIB_CASSERT( this->stream_is_set() == true, "\tstd::istream& cpp_tokenizer::get_stream()" << "\n\tyou must set a stream for this object before you can get it" << "\n\tthis: " << this ); // call the real function return tokenizer::get_stream(); } // 
---------------------------------------------------------------------------------------- template < typename tokenizer > const std::string& cpp_tokenizer_kernel_c:: peek_token ( ) const { // make sure requires clause is not broken DLIB_CASSERT( this->stream_is_set() == true, "\tconst std::string& cpp_tokenizer::peek_token()" << "\n\tyou must set a stream for this object before you can peek at what it contains" << "\n\tthis: " << this ); // call the real function return tokenizer::peek_token(); } // ---------------------------------------------------------------------------------------- template < typename tokenizer > int cpp_tokenizer_kernel_c:: peek_type ( ) const { // make sure requires clause is not broken DLIB_CASSERT( this->stream_is_set() == true, "\tint cpp_tokenizer::peek_type()" << "\n\tyou must set a stream for this object before you can peek at what it contains" << "\n\tthis: " << this ); // call the real function return tokenizer::peek_type(); } // ---------------------------------------------------------------------------------------- template < typename tokenizer > void cpp_tokenizer_kernel_c:: get_token ( int& type, std::string& token ) { // make sure requires clause is not broken DLIB_CASSERT( this->stream_is_set() == true, "\tvoid cpp_tokenizer::get_token()" << "\n\tyou must set a stream for this object before you can get tokens from it." << "\n\tthis: " << this ); // call the real function tokenizer::get_token(type,token); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_TOKENIZER_KERNEl_C_ ================================================ FILE: dlib/cpp_tokenizer.h ================================================ // Copyright (C) 2005 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_CPP_TOKENIZEr_
#define DLIB_CPP_TOKENIZEr_

// NOTE(review): in this extract the target of the next #include, and the template
// argument lists on set, queue, cpp_tokenizer_kernel_1 and cpp_tokenizer_kernel_c
// below, appear to have been stripped; restore them from the upstream header.
#include
#include "cpp_tokenizer/cpp_tokenizer_kernel_1.h"
#include "cpp_tokenizer/cpp_tokenizer_kernel_c.h"
#include "tokenizer.h"
#include "queue.h"
#include "set.h"

namespace dlib
{

    // Type-selector class: its only constructor is private, so it cannot be created
    // from outside; it exists to expose the cpp_tokenizer implementations as typedefs.
    class cpp_tokenizer
    {
        cpp_tokenizer() {}

        // Component types used to assemble the kernel typedefs below.
        typedef set::kernel_1a set;
        typedef queue::kernel_2a queue;
        typedef tokenizer::kernel_1a tok;

    public:

        //----------- kernels ---------------

        // kernel_1a
        typedef     cpp_tokenizer_kernel_1
                    kernel_1a;
        // kernel_1a wrapped in cpp_tokenizer_kernel_c, which adds DLIB_CASSERT checks
        // that a stream is set before get_stream/peek/get calls.
        typedef     cpp_tokenizer_kernel_c
                    kernel_1a_c;

    };
}

#endif // DLIB_CPP_TOKENIZEr_

================================================
FILE: dlib/crc32/crc32_kernel_1.h
================================================
// Copyright (C) 2005 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CRC32_KERNEl_1_
#define DLIB_CRC32_KERNEl_1_

#include "../algs.h"
// NOTE(review): the targets of the next two #include lines (and the element type of
// the std::vector parameters below) appear to have been stripped in this extract;
// std::string and std::vector are used, so presumably <string> and <vector>.
#include
#include
#include "crc32_kernel_abstract.h"

namespace dlib
{

    class crc32
    {
        /*!
            INITIAL VALUE
                checksum == 0xFFFFFFFF

            CONVENTION
                get_checksum() == checksum ^ 0xFFFFFFFF

            IMPLEMENTATION NOTES
                Table-driven CRC-32 that consumes one byte per step using the
                reflected polynomial 0xedb88320 (see the generator code kept in
                table()).  The shift register starts at all ones and the reported
                checksum is the register with every bit inverted.
        !*/

    public:

        // this is here for backwards compatibility with older versions of dlib.
        typedef crc32 kernel_1a;

        inline crc32 (
        );

        inline crc32 (
            const std::string& item
        );

        inline crc32 (
            const std::vector& item
        );

        inline virtual ~crc32 (
        );

        inline void clear(
        );

        inline void add (
            unsigned char item
        );

        inline void add (
            const std::string& item
        );

        inline void add (
            const std::vector& item
        );

        // Implicit conversion so a crc32 can be used wherever its checksum value is wanted.
        inline operator unsigned long (
        ) const { return get_checksum(); }

        inline unsigned long get_checksum (
        ) const;

        inline void swap (
            crc32& item
        );

    private:

        // CRC shift register; see CONVENTION above for how it maps to get_checksum().
        unsigned long checksum;

        // Returns entry idx of the 256-entry CRC-32 lookup table.  The add() overloads
        // always mask the index with 0xFF, so idx stays within [0, 255].
        inline unsigned long table (
            unsigned int idx
        ) const
        {
            /*
                // This code generates the crc_table used below.
                unsigned long crc_table[256];
                for (unsigned long i = 0; i < 256; ++i)
                {
                    unsigned long temp = i;
                    for (unsigned long j = 0; j < 8; ++j)
                    {
                        if (temp&1)
                            temp = (temp>>1)^0xedb88320;
                        else
                            temp >>= 1;
                    }
                    crc_table[i] = temp;
                    std::cout << std::hex << crc_table[i] << std::endl;
                }
            */

            // Precomputed output of the generator above (entries 0..255).
            const static unsigned long crc_table[256] = {
                0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x76dc419, 0x706af48f, 0xe963a535, 0x9e6495a3,
                0xedb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x9b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91,
                0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7,
                0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5,
                0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b,
                0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59,
                0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f,
                0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d,
                0x76dc4190, 0x1db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x6b6b51f, 0x9fbfe4a5, 0xe8b8d433,
                0x7807c9a2, 0xf00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x86d3d2d, 0x91646c97, 0xe6635c01,
                0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
                0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65,
                0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb,
                0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9,
                0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f,
                0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad,
                0xedb88320, 0x9abfb3b6, 0x3b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x4db2615, 0x73dc1683,
                0xe3630b12, 0x94643b84, 0xd6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0xa00ae27, 0x7d079eb1,
                0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7,
                0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5,
                0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b,
                0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
                0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f,
                0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d,
                0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x26d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x5005713,
                0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0xcb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0xbdbdf21,
                0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777,
                0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45,
                0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db,
                0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9,
                0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf,
                0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d
            };

            return crc_table[idx];
        }

    };

    // Global swap so crc32 objects can be exchanged via dlib's swap convention.
    inline void swap (
        crc32& a,
        crc32& b
    ) { a.swap(b); }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    // member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // Default constructor: register starts at all ones (see INITIAL VALUE).
    crc32::
    crc32 (
    )
    {
        checksum = 0xFFFFFFFF;
    }

//
---------------------------------------------------------------------------------------- crc32:: crc32 ( const std::string& item ) { checksum = 0xFFFFFFFF; add(item); } // ---------------------------------------------------------------------------------------- crc32:: crc32 ( const std::vector& item ) { checksum = 0xFFFFFFFF; add(item); } // ---------------------------------------------------------------------------------------- crc32:: ~crc32 ( ) { } // ---------------------------------------------------------------------------------------- void crc32:: clear( ) { checksum = 0xFFFFFFFF; } // ---------------------------------------------------------------------------------------- void crc32:: add ( unsigned char item ) { checksum = (checksum>>8) ^ table((checksum^item) & 0xFF); } // ---------------------------------------------------------------------------------------- void crc32:: add ( const std::string& item ) { for (std::string::size_type i = 0; i < item.size(); ++i) checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF); } // ---------------------------------------------------------------------------------------- void crc32:: add ( const std::vector& item ) { for (unsigned long i = 0; i < item.size(); ++i) checksum = (checksum>>8) ^ table((checksum^item[i]) & 0xFF); } // ---------------------------------------------------------------------------------------- unsigned long crc32:: get_checksum ( ) const { return checksum ^ 0xFFFFFFFF; } // ---------------------------------------------------------------------------------------- void crc32:: swap ( crc32& item ) { exchange(checksum,item.checksum); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CRC32_KERNEl_1_ ================================================ FILE: dlib/crc32/crc32_kernel_abstract.h ================================================ // Copyright (C) 2005 Davis E. 
King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CRC32_KERNEl_ABSTRACT_ #ifdef DLIB_CRC32_KERNEl_ABSTRACT_ #include "../algs.h" #include #include namespace dlib { class crc32 { /*! INITIAL VALUE The current checksum covers zero bytes. get_checksum() == 0x00000000 WHAT THIS OBJECT REPRESENTS This object represents the CRC32 algorithm for calculating checksums. !*/ public: crc32 ( ); /*! ensures - #*this is properly initialized !*/ crc32 ( const std::string& item ); /*! ensures - #*this is properly initialized - calls this->add(item). (i.e. Using this constructor is the same as using the default constructor and then calling add() on item) !*/ crc32 ( const std::vector& item ); /*! ensures - #*this is properly initialized - calls this->add(item). (i.e. Using this constructor is the same as using the default constructor and then calling add() on item) !*/ virtual ~crc32 ( ); /*! ensures - any resources associated with *this have been released !*/ void clear( ); /*! ensures - #*this has its initial value !*/ void add ( unsigned char item ); /*! ensures - #get_checksum() == The checksum of all items added to *this previously concatenated with item. !*/ void add ( const std::string& item ); /*! ensures - #get_checksum() == The checksum of all items added to *this previously concatenated with item. !*/ void add ( const std::vector& item ); /*! ensures - #get_checksum() == The checksum of all items added to *this previously concatenated with item. !*/ unsigned long get_checksum ( ) const; /*! ensures - returns the current checksum !*/ operator unsigned long ( ) const; /*! ensures - returns get_checksum() !*/ void swap ( crc32& item ); /*! ensures - swaps *this and item !*/ }; void swap ( crc32& a, crc32& b ) { a.swap(b); } /*! 
provides a global swap function
!*/

}

#endif // DLIB_CRC32_KERNEl_ABSTRACT_

================================================
FILE: dlib/crc32.h
================================================
// Copyright (C) 2005 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CRc32_
#define DLIB_CRc32_

// Public entry point for dlib's crc32; the implementation lives in crc32_kernel_1.h.
#include "crc32/crc32_kernel_1.h"

#endif // DLIB_CRc32_

================================================
FILE: dlib/cstring
================================================
// NOTE(review): presumably exists so that an include path pointing inside dlib/ fails
// loudly (with the include-path tutorial) instead of shadowing <cstring>; confirm.
#include "dlib_include_path_tutorial.txt"

================================================
FILE: dlib/cuda/cpu_dlib.cpp
================================================
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CPU_cPP_
#define DLIB_DNN_CPU_cPP_

// This file contains CPU implementations of the GPU based functions in cuda_dlib.h

#include "cpu_dlib.h"
#include "tensor_tools.h"
#include "../image_transforms/interpolation.h"
#include "../threads.h"

namespace dlib
{
    namespace cpu
    {

    // -----------------------------------------------------------------------------------

        // Element-wise product: dest = src1*src2, or dest += src1*src2 when add_to is
        // true.  All three tensors must agree in k(), nr() and nc(); each num_samples()
        // must be 1 or MD (the largest of the three), so a 1-sample tensor is broadcast
        // across samples.  When dest has a single sample, the products of all samples
        // are summed into it.
        void multiply (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        )
        {
            DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() &&
                dest.nr() == src1.nr() && src1.nr() == src2.nr() &&
                dest.nc() == src1.nc() && src1.nc() == src2.nc() );
            const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples());
            DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) &&
                (src1.num_samples()==1 || src1.num_samples()==MD) &&
                (src2.num_samples()==1 || src2.num_samples()==MD) );

            if (dest.size() == 0)
                return;

            const size_t max_size = std::max(std::max(dest.size(),src1.size()),src2.size());
            const auto d = dest.host();
            const auto s1 = src1.host();
            const auto s2 = src2.host();
            if (dest.size() == src1.size() && src1.size() == src2.size())
            {
                // All sizes equal: plain element-wise loop, no broadcasting.
                if (add_to)
                {
                    for (size_t i = 0; i < src1.size(); ++i)
                        d[i] += s1[i]*s2[i];
                }
                else
                {
                    for (size_t i = 0; i < src1.size(); ++i)
                        d[i] = s1[i]*s2[i];
                }
            }
            else if (dest.num_samples() == 1)
            {
                // dest is a single sample: accumulate every sample's product into it.
                // The % wraps indices so 1-sample operands are reused for each sample.
                if (!add_to)
                {
                    for (size_t i = 0; i < dest.size(); ++i)
                        d[i] = 0;
                }
                for (size_t i = 0; i < max_size; ++i)
                    d[i%dest.size()] += s1[i%src1.size()]*s2[i%src2.size()];
            }
            else
            {
                // dest has MD samples; the % broadcasts whichever source has one sample.
                if (add_to)
                {
                    for (size_t i = 0; i < max_size; ++i)
                        d[i] += s1[i%src1.size()]*s2[i%src2.size()];
                }
                else
                {
                    for (size_t i = 0; i < max_size; ++i)
                        d[i] = s1[i%src1.size()]*s2[i%src2.size()];
                }
            }
        }

    // ------------------------------------------------------------------------------------

        // Per-channel multiply, in one of two modes:
        //  - dest has src1's dimensions: dest (+)= src1 * src2[k], where src2 holds one
        //    value per channel k;
        //  - otherwise dest holds one value per channel and receives (or accumulates)
        //    the sum over samples and positions of src1*src2 for that channel.
        // The pointer walks visit memory in (n,k,r,c) order.
        void multiply_conv (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        )
        {
            auto d = dest.host();
            auto s1 = src1.host();
            auto s2 = src2.host();
            if (have_same_dimensions(dest,src1))
            {
                DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k());

                if (add_to)
                {
                    for (long n = 0; n < dest.num_samples(); ++n)
                    {
                        for (long k = 0; k < dest.k(); ++k)
                        {
                            for (long r = 0; r < dest.nr(); ++r)
                            {
                                for (long c = 0; c < dest.nc(); ++c)
                                {
                                    *d++ += (*s1++)*s2[k];
                                }
                            }
                        }
                    }
                }
                else
                {
                    for (long n = 0; n < dest.num_samples(); ++n)
                    {
                        for (long k = 0; k < dest.k(); ++k)
                        {
                            for (long r = 0; r < dest.nr(); ++r)
                            {
                                for (long c = 0; c < dest.nc(); ++c)
                                {
                                    *d++ = (*s1++)*s2[k];
                                }
                            }
                        }
                    }
                }
            }
            else
            {
                DLIB_CASSERT(have_same_dimensions(src1,src2));
                DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k());

                if (!add_to)
                {
                    for (long k = 0; k < src1.k(); ++k)
                        d[k] = 0;
                }

                for (long n = 0; n < src1.num_samples(); ++n)
                {
                    for (long k = 0; k < src1.k(); ++k)
                    {
                        for (long r = 0; r < src1.nr(); ++r)
                        {
                            for (long c = 0; c < src1.nc(); ++c)
                            {
                                d[k] += (*s1++)*(*s2++);
                            }
                        }
                    }
                }
            }
        }

    // ------------------------------------------------------------------------------------

        // dest = src*scale (or dest += src*scale when add_to is true), where scales holds
        // one factor per (sample, channel) pair that multiplies every spatial position of
        // that channel.
        void scale_channels (
            bool add_to,
            tensor& dest,
            const tensor& src,
            const tensor& scales
        )
        {
            DLIB_CASSERT(have_same_dimensions(dest,src) &&
                         scales.num_samples() == src.num_samples() &&
                         scales.k() == src.k() &&
                         scales.nr() == 1 &&
                         scales.nc() == 1 );

            if (dest.size() == 0)
                return;

            if (add_to)
            {
                auto d = dest.host();
                auto s = src.host();
                auto scal = scales.host();

                for (long n = 0; n < src.num_samples(); ++n)
                {
                    for (long k = 0; k < src.k(); ++k)
                    {
                        const auto scale = scal[n*scales.k() + k];
                        for (long r = 0; r < src.nr(); ++r)
                        {
                            for (long c = 0; c < src.nc(); ++c)
                            {
                                *d++ += (*s++) * scale;
                            }
                        }
                    }
                }
            }
            else
            {
                // dest is fully overwritten here, so it is fetched with host_write_only()
                // (presumably to skip syncing its old contents; confirm in tensor docs).
                auto d = dest.host_write_only();
                auto s = src.host();
                auto scal = scales.host();

                for (long n = 0; n < src.num_samples(); ++n)
                {
                    for (long k = 0; k < src.k(); ++k)
                    {
                        const auto scale = scal[n*scales.k() + k];
                        for (long r = 0; r < src.nr(); ++r)
                        {
                            for (long c = 0; c < src.nc(); ++c)
                            {
                                *d++ = (*s++) * scale;
                            }
                        }
                    }
                }
            }
        }

    // ------------------------------------------------------------------------------------

        // dest = beta*dest + alpha*src.  src either matches dest or is broadcast along
        // every dimension where it has size 1 (only the shape combinations listed in the
        // assert are accepted), and src must not be the same object as dest.
        void add(
            float beta,
            tensor& dest,
            float alpha,
            const tensor& src
        )
        {
            DLIB_CASSERT(
                  (have_same_dimensions(src, dest) ||
                  (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
                  (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
                  (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
                  (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)) &&
                  is_same_object(src,dest) == false ,
                    "\n\t dest.num_samples(): " << dest.num_samples()
                    <<"\n\t dest.k(): " << dest.k()
                    <<"\n\t dest.nr(): " << dest.nr()
                    <<"\n\t dest.nc(): " << dest.nc()
                    <<"\n\t src.num_samples(): " << src.num_samples()
                    <<"\n\t src.k(): " << src.k()
                    <<"\n\t src.nr(): " << src.nr()
                    <<"\n\t src.nc(): " << src.nc()
                    );

            // Both coefficients zero: just fill dest with 0 without reading dest or src.
            if (beta == 0 && alpha == 0)
            {
                dest = 0;
                return;
            }

            auto d = dest.host();
            auto s = src.host();
            for (long n = 0; n < dest.num_samples(); ++n)
            {
                // Use index 0 along any dimension where src has size 1 (broadcast).
                const auto sn = src.num_samples()==1 ? 0:n;
                for (long k = 0; k < dest.k(); ++k)
                {
                    const auto sk = src.k()==1 ? 0:k;
                    for (long r = 0; r < dest.nr(); ++r)
                    {
                        const auto sr = src.nr()==1 ? 0:r;
                        for (long c = 0; c < dest.nc(); ++c)
                        {
                            const auto sc = src.nc()==1 ? 0:c;

                            const auto s_idx = ((sn*src.k() + sk)*src.nr() + sr)*src.nc() + sc;
                            *d = beta*(*d) + alpha*s[s_idx];
                            ++d;
                        }
                    }
                }
            }
        }

    // ----------------------------------------------------------------------------------------

        // dest = src1 + src2 with zero padding: wherever an (n,k,r,c) index of dest lies
        // outside src1 or src2, that operand contributes 0.  No shape asserts are made here.
        void add (
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        )
        {
            auto d = dest.host();
            auto s1 = src1.host();
            auto s2 = src2.host();

            // Do the simple and fast version if everything has the same dimensions
            if (have_same_dimensions(dest, src1) &&
                have_same_dimensions(dest, src2))
            {
                for (size_t i = 0; i < dest.size(); ++i)
                    d[i] = s1[i] + s2[i];
                return;
            }

            // Otherwise, do the more complex version with bounds checking.
            for (long n = 0; n < dest.num_samples(); ++n)
            {
                for (long k = 0; k < dest.k(); ++k)
                {
                    for (long r = 0; r < dest.nr(); ++r)
                    {
                        for (long c = 0; c < dest.nc(); ++c)
                        {
                            float v1 = 0;
                            float v2 = 0;

                            // if this index is inside src1
                            if (n < src1.num_samples() &&
                                k < src1.k() &&
                                r < src1.nr() &&
                                c < src1.nc() )
                            {
                                const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c;
                                v1 = s1[s_idx];
                            }

                            // if this index is inside src2
                            if (n < src2.num_samples() &&
                                k < src2.k() &&
                                r < src2.nr() &&
                                c < src2.nc() )
                            {
                                const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c;
                                v2 = s2[s_idx];
                            }

                            *d = v1 + v2;
                            ++d;
                        }
                    }
                }
            }
        }

    // ----------------------------------------------------------------------------------------

        // dest = src1*src2 (or dest += src1*src2 when add_to is true) with zero padding:
        // wherever an (n,k,r,c) index of dest lies outside src1 or src2, that operand is
        // treated as 0.  No shape asserts are made here.
        void multiply_zero_padded (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        )
        {
            auto d = dest.host();
            auto s1 = src1.host();
            auto s2 = src2.host();

            // Do the simple and fast version if everything has the same dimensions
            if (have_same_dimensions(dest, src1) &&
                have_same_dimensions(dest, src2))
            {
                if (add_to)
                {
                    for (size_t i = 0; i < dest.size(); ++i)
                        d[i] += s1[i] * s2[i];
                }
                else
                {
                    for (size_t i = 0; i < dest.size(); ++i)
                        d[i] = s1[i] * s2[i];
                }
                return;
            }

            // Otherwise, do the more complex version with bounds checking.
            for (long n = 0; n < dest.num_samples(); ++n)
            {
                for (long k = 0; k < dest.k(); ++k)
                {
                    for (long r = 0; r < dest.nr(); ++r)
                    {
                        for (long c = 0; c < dest.nc(); ++c)
                        {
                            float v1 = 0;
                            float v2 = 0;

                            // if this index is inside src1
                            if (n < src1.num_samples() &&
                                k < src1.k() &&
                                r < src1.nr() &&
                                c < src1.nc() )
                            {
                                const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c;
                                v1 = s1[s_idx];
                            }

                            // if this index is inside src2
                            if (n < src2.num_samples() &&
                                k < src2.k() &&
                                r < src2.nr() &&
                                c < src2.nc() )
                            {
                                const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c;
                                v2 = s2[s_idx];
                            }

                            if (add_to)
                                *d += v1 * v2;
                            else
                                *d = v1 * v2;
                            ++d;
                        }
                    }
                }
            }
        }

    // ----------------------------------------------------------------------------------------

        // grad = sum over samples of gradient_input; grad is a single sample with the same
        // k/nr/nc as gradient_input.
        void assign_bias_gradient (
            tensor& grad,
            const tensor& gradient_input
        )
        {
            DLIB_CASSERT(
                  grad.num_samples() == 1 &&
                  gradient_input.k() == grad.k() &&
                  gradient_input.nr() == grad.nr() &&
                  gradient_input.nc() == grad.nc() &&
                  gradient_input.size() > 0);

            auto out = grad.host();
            auto in = gradient_input.host();

            // Copy the first sample, then add each remaining sample on top.
            for (size_t i = 0; i < grad.size(); ++i)
                out[i] = *in++;

            for (long j = 1; j < gradient_input.num_samples(); ++j)
            {
                for (size_t i = 0; i < grad.size(); ++i)
                    out[i] += *in++;
            }
        }

    // ------------------------------------------------------------------------------------

        // grad[k] = sum of gradient_input over all samples and spatial positions of
        // channel k; grad has one sample, k channels and a 1x1 spatial extent.
        void assign_conv_bias_gradient (
            tensor& grad,
            const tensor& gradient_input
        )
        {
            DLIB_CASSERT(
                  grad.num_samples() == 1 &&
                  grad.k() >= 1 &&
                  grad.nr() == 1 &&
                  grad.nc() == 1 &&
                  gradient_input.k() == grad.k() &&
                  gradient_input.size() > 0 &&
                  is_same_object(grad,gradient_input) == false
                  );

            auto g = grad.host();
            auto gi = gradient_input.host();

            for (long k = 0; k < gradient_input.k(); ++k)
                g[k] = 0;

            for (long n = 0; n < gradient_input.num_samples(); ++n)
            {
                for (long k = 0; k < gradient_input.k(); ++k)
                {
                    for (long r = 0; r < gradient_input.nr(); ++r)
                    {
                        for (long c = 0; c < gradient_input.nc(); ++c)
                        {
                            g[k] += (*gi++);
                        }
                    }
                }
            }
        }

    //
----------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const float A, const float B ) { DLIB_CASSERT(dest.size()==src.size()); const auto d = dest.host(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) d[i] = A*s[i] + B; } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B, const float C ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); const auto d = dest.host(); const auto s1 = src1.host(); const auto s2 = src2.host(); for (size_t i = 0; i < src1.size(); ++i) d[i] = A*s1[i] + B*s2[i] + C; } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C, const float D ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); DLIB_CASSERT(dest.size()==src3.size()); const auto d = dest.host(); const auto s1 = src1.host(); const auto s2 = src2.host(); const auto s3 = src3.host(); for (size_t i = 0; i < src1.size(); ++i) d[i] = A*s1[i] + B*s2[i] + C*s3[i] + D; } void affine_transform_range( size_t begin, size_t end, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); DLIB_CASSERT(dest.size()==src3.size()); DLIB_CASSERT(begin <= end && end <= dest.size()); const auto d = dest.host(); const auto s1 = src1.host(); const auto s2 = src2.host(); const auto s3 = src3.host(); for (size_t i = begin; i < end; ++i) d[i] = A*s1[i] + B*s2[i] + C*s3[i]; } // ----------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const tensor& A, const tensor& B ) { DLIB_CASSERT(have_same_dimensions(dest,src)); DLIB_CASSERT( ((A.num_samples()==1 && B.num_samples()==1) || 
(A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) && A.nr()==B.nr() && B.nr()==src.nr() && A.nc()==B.nc() && B.nc()==src.nc() && A.k() ==B.k() && B.k()==src.k()); auto d = dest.host(); auto s = src.host(); const auto a = A.host(); const auto b = B.host(); if (A.num_samples() == 1) { const long num = src.size()/src.num_samples(); for (long i = 0; i < src.num_samples(); ++i) { for (long j = 0; j < num; ++j) { *d = a[j]*(*s) + b[j]; d++; s++; } } } else { for (size_t i = 0; i < src.size(); ++i) d[i] = a[i]*s[i] + b[i]; } } // ----------------------------------------------------------------------------------- void affine_transform_conv( tensor& dest, const tensor& src, const tensor& A, const tensor& B ) { DLIB_CASSERT(have_same_dimensions(dest,src)); DLIB_CASSERT(have_same_dimensions(A,B)); DLIB_CASSERT(A.num_samples() == 1 && A.nr() == 1 && A.nc() == 1 && A.k() == src.k()); auto d = dest.host(); auto s = src.host(); const auto a = A.host(); const auto b = B.host(); for (long n = 0; n < dest.num_samples(); ++n) { for (long k = 0; k < dest.k(); ++k) { for (long r = 0; r < dest.nr(); ++r) { for (long c = 0; c < dest.nc(); ++c) { *d++ = a[k]*(*s++) + b[k]; } } } } } // ---------------------------------------------------------------------------------------- void affine_transform( const rectangle& rect, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, float A, float B, float C ) { DLIB_CASSERT(dest.size() == src1.size()); DLIB_CASSERT(dest.size() == src2.size()); DLIB_CASSERT(dest.size() == src3.size()); DLIB_CASSERT(dest.num_samples() == src1.num_samples()); DLIB_CASSERT(dest.num_samples() == src2.num_samples()); DLIB_CASSERT(dest.num_samples() == src3.num_samples()); DLIB_CASSERT(rectangle(0,0, dest.size()/dest.num_samples()-1, dest.num_samples()-1).contains(rect)); auto d = dest.host(); auto s1 = src1.host(); auto s2 = src2.host(); auto s3 = src3.host(); const auto nc = dest.size()/dest.num_samples(); for (long r = 
rect.top(); r <= rect.bottom(); ++r) { for (long c = rect.left(); c <= rect.right(); ++c) { auto idx = r*nc + c; d[idx] = s1[idx]*A + s2[idx]*B + s3[idx]*C; } } } // ----------------------------------------------------------------------------------- void compute_adam_update ( size_t begin, size_t end, tensor& s, tensor& m, tensor& v, const float t, const float learning_rate, const float weight_decay, const float momentum1, const float momentum2, const tensor& params, const tensor& params_grad ) { DLIB_CASSERT(s.size() == m.size() && s.size() == v.size() && s.size() == params.size() && s.size() == params_grad.size()); DLIB_CASSERT(begin <= end && end <= params.size()); const float eps = 1e-8; const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t)); // The loop is equivalent to doing this: // m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad); // v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad); // s = -alpha*m/(sqrt(v) + eps); auto pm = m.host(); auto pv = v.host(); auto ps = s.host_write_only(); auto pparams = params.host(); auto ppgrad = params_grad.host(); for (size_t i = begin; i < end; ++i) { float g = weight_decay*pparams[i] + ppgrad[i]; pm[i] = momentum1*pm[i] + (1-momentum1)*g; pv[i] = momentum2*pv[i] + (1-momentum2)*g*g; ps[i] = -alpha*pm[i]/(std::sqrt(pv[i]) + eps); } } // ----------------------------------------------------------------------------------- void batch_normalize_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ) { DLIB_CASSERT( gamma.num_samples() == 1 && gamma.nr() == src.nr() && gamma.nc() == src.nc() && gamma.k() == src.k() && have_same_dimensions(gamma, beta) && have_same_dimensions(gamma, running_means) && have_same_dimensions(gamma, running_variances) && eps > 0, "\ngamma.num_samples(): " << gamma.num_samples() << "\ngamma.k(): " 
<< gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.num_samples(): " << beta.num_samples() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\nrunning_means.num_samples(): " << running_means.num_samples() << "\nrunning_means.k(): " << running_means.k() << "\nrunning_means.nr(): " << running_means.nr() << "\nrunning_means.nc(): " << running_means.nc() << "\nrunning_variances.num_samples(): " << running_variances.num_samples() << "\nrunning_variances.k(): " << running_variances.k() << "\nrunning_variances.nr(): " << running_variances.nr() << "\nrunning_variances.nc(): " << running_variances.nc() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); dest.copy_size(src); auto d = dest.host(); auto s = src.host(); auto g = gamma.host(); auto b = beta.host(); auto m = running_means.host(); auto v = running_variances.host(); const long num = src.k()*src.nr()*src.nc(); for (long n = 0; n < src.num_samples(); ++n) { for (long k = 0; k < num; ++k) { *d = g[k]*(*s - m[k])/std::sqrt(v[k]+eps) + b[k]; ++d; ++s; } } } void batch_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ) { DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means)); DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); DLIB_CASSERT( src.num_samples() > 1 && gamma.num_samples() == 1 && beta.num_samples() == 1 && gamma.nr() == beta.nr() && beta.nr() == src.nr() && gamma.nc() == beta.nc() && beta.nc() == src.nc() && gamma.k() == beta.k() && beta.k() == src.k() && eps > 0, "\ngamma.num_samples(): " << 
gamma.num_samples() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.num_samples(): " << beta.num_samples() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); dest.copy_size(src); means.set_size(1, src.k(), src.nr(), src.nc()); invstds.set_size(1, src.k(), src.nr(), src.nc()); running_means.set_size(1, src.k(), src.nr(), src.nc()); running_variances.set_size(1, src.k(), src.nr(), src.nc()); // first compute means and invstds const auto p_invstds = invstds.host(); const auto p_means = means.host(); auto p_src = src.host(); const auto rvar = running_variances.host(); const long num = src.k()*src.nr()*src.nc(); // This scale makes the running variances unbiased. const double scale = (src.num_samples())/(src.num_samples()-1.0); // Apply Welford's algorithm to improve numerical stability for (long i = 0; i < num; ++i) { double mean = 0.0; double M2 = 0.0; for (long n = 0; n < src.num_samples(); ++n) { float val = p_src[n*num+i]; const double delta1 = val - mean; mean += delta1 / (n + 1); const double delta2 = val - mean; M2 += delta1 * delta2; } p_means[i] = mean; const auto actual_var = (src.num_samples() > 1) ? 
(M2 / src.num_samples()) : 0.0;
        // (Tail of batch_normalize(); its beginning lies above this chunk.)
        // Running variance: overwritten on the first update (averaging_factor == 1),
        // otherwise blended as an exponential moving average.  `scale` is defined above
        // this chunk -- presumably the N/(N-1) unbiasing factor used by
        // batch_normalize_conv; TODO confirm.
        if (averaging_factor == 1)
            rvar[i] = scale*actual_var;
        else
            rvar[i] = (1-averaging_factor)*rvar[i] + scale*averaging_factor*actual_var;

        // Note the biased variance (actual_var) is what is used for normalization.
        p_invstds[i] = 1.0f/std::sqrt(actual_var + eps);
    }

    // Normalize every element, then apply the per-position gamma/beta.
    auto p_dest = dest.host();
    const auto p_gamma = gamma.host();
    const auto p_beta = beta.host();
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long i = 0; i < num; ++i)
        {
            *p_dest = (*p_src - p_means[i])*p_invstds[i];
            *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
            ++p_src;
            ++p_dest;
        }
    }

    // now keep track of the running means
    if (averaging_factor != 1)
        running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
    else
        running_means = means;
}

/*
    Backward pass of (non-convolutional) batch normalization.  Statistics are per element
    position i, where num = src.k()*src.nr()*src.nc(), taken across the samples.

    - beta_grad and gamma_grad are zeroed and then filled (overwritten).
    - The gradient w.r.t. src is ADDED into src_grad (+=), not assigned.
    - Assumes invstds[i] holds 1/sqrt(var[i] + eps) as produced by the forward pass
      (TODO confirm against the caller); the dvars term uses invstds^3.

    Three passes over the data: (1) beta/gamma gradients and dL/dvar, (2) dL/dmean,
    (3) dL/dsrc.
*/
void batch_normalize_gradient (
    const double eps,
    const tensor& gradient_input,
    const tensor& means,
    const tensor& invstds,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    tensor& beta_grad
)
{
    const long num = src.k()*src.nr()*src.nc();
    DLIB_CASSERT(src.num_samples() > 1);
    DLIB_CASSERT(num == (long)means.size());
    DLIB_CASSERT(num == (long)invstds.size());
    DLIB_CASSERT(num == (long)gamma.size());
    DLIB_CASSERT(num == (long)gamma_grad.size());
    DLIB_CASSERT(num == (long)beta_grad.size());
    DLIB_CASSERT(have_same_dimensions(gradient_input, src));
    DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
    DLIB_CASSERT(eps > 0);

    beta_grad = 0;
    gamma_grad = 0;
    auto p_grad = gradient_input.host();
    auto p_src = src.host();
    const auto p_gamma = gamma.host();
    const auto p_gamma_grad = gamma_grad.host();
    const auto p_beta_grad = beta_grad.host();
    const auto p_invstds = invstds.host();
    const auto p_means = means.host();

    // Scratch accumulators for dL/dvar and dL/dmean, one per position i.
    resizable_tensor dvars, dmeans;
    dvars.copy_size(invstds);
    dmeans.copy_size(means);
    dvars = 0;
    dmeans = 0;
    const auto p_dvars = dvars.host();
    const auto p_dmeans = dmeans.host();

    // Pass 1: beta/gamma gradients and dL/dvar.
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long i = 0; i < num; ++i)
        {
            const float x_hat = (*p_src - p_means[i])*p_invstds[i];
            p_beta_grad[i] += *p_grad;
            p_gamma_grad[i] += (*p_grad)*x_hat;

            // Gradient w.r.t. the normalized value x_hat.
            const float dx = *p_grad * p_gamma[i];

            // Expression order is kept as-is: -0.5 is a double, so the product is
            // evaluated in double before being accumulated into the float.
            p_dvars[i] += dx*(*p_src - p_means[i])*-0.5*std::pow(p_invstds[i], 3.0f);

            ++p_grad;
            ++p_src;
        }
    }

    // Pass 2: dL/dmean.
    const float invnum = 1.0f/src.num_samples();
    p_grad = gradient_input.host();
    p_src = src.host();
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long i = 0; i < num; ++i)
        {
            const float dx = *p_grad * p_gamma[i];

            p_dmeans[i] += dx*-p_invstds[i] + p_dvars[i] * -2*(*p_src - p_means[i])*invnum;

            ++p_grad;
            ++p_src;
        }
    }

    // Pass 3: dL/dsrc, accumulated into src_grad.
    p_grad = gradient_input.host();
    p_src = src.host();
    auto p_src_grad = src_grad.host();
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long i = 0; i < num; ++i)
        {
            const float dx = *p_grad * p_gamma[i];

            *p_src_grad += dx*p_invstds[i] +
                p_dvars[i] *2*(*p_src - p_means[i])*invnum +
                p_dmeans[i]*invnum;


            ++p_grad;
            ++p_src;
            ++p_src_grad;
        }
    }
}

// ----------------------------------------------------------------------------------------

/*
    Inference-mode convolutional batch normalization.  For each channel k:
        dest = gamma[k]*(src - running_means[k])/sqrt(running_variances[k] + eps) + beta[k]
    gamma, beta, running_means and running_variances must be 1 x k x 1 x 1 with
    k == src.k().  dest is resized to match src.
*/
void batch_normalize_conv_inference (
    const double eps,
    resizable_tensor& dest,
    const tensor& src,
    const tensor& gamma,
    const tensor& beta,
    const tensor& running_means,
    const tensor& running_variances
)
{
    DLIB_CASSERT(
        gamma.num_samples() == 1 &&
        gamma.nr() == 1 &&
        gamma.nc() == 1 &&
        gamma.k() == src.k() &&
        have_same_dimensions(gamma, beta) &&
        have_same_dimensions(gamma, running_means) &&
        have_same_dimensions(gamma, running_variances) &&
        eps > 0,
        "\ngamma.num_samples(): " << gamma.num_samples() <<
        "\ngamma.k(): " << gamma.k() <<
        "\ngamma.nr(): " << gamma.nr() <<
        "\ngamma.nc(): " << gamma.nc() <<
        "\nbeta.num_samples(): " << beta.num_samples() <<
        "\nbeta.k(): " << beta.k() <<
        "\nbeta.nr(): " << beta.nr() <<
        "\nbeta.nc(): " << beta.nc() <<
        "\nrunning_means.num_samples(): " << running_means.num_samples() <<
        "\nrunning_means.k(): " << running_means.k() <<
        "\nrunning_means.nr(): " << running_means.nr() <<
        "\nrunning_means.nc(): " << running_means.nc() <<
        "\nrunning_variances.num_samples(): " << running_variances.num_samples() <<
        "\nrunning_variances.k(): " << running_variances.k() <<
        "\nrunning_variances.nr(): " << running_variances.nr() <<
        "\nrunning_variances.nc(): " << running_variances.nc() <<
        "\nsrc.k(): " << src.k() <<
        "\nsrc.nr(): " << src.nr() <<
        "\nsrc.nc(): " << src.nc() <<
        "\neps: " << eps
    );
    dest.copy_size(src);

    auto d = dest.host();
    auto s = src.host();
    auto g = gamma.host();
    auto b = beta.host();
    auto m = running_means.host();
    auto v = running_variances.host();

    // Walk src/dest linearly in (sample, channel, spatial) order.
    const long num = src.nr()*src.nc();
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long k = 0; k < src.k(); ++k)
        {
            const float invstd = 1.0f/std::sqrt(v[k] + eps);
            for (long j = 0; j < num; ++j)
            {
                *d = g[k]*(*s - m[k])*invstd + b[k];
                ++d;
                ++s;
            }
        }
    }
}

/*
    Training-mode convolutional batch normalization.

    - For each channel k, computes the mean and biased variance over all samples and all
      spatial positions with Welford's algorithm.
    - Writes means (1 x k) and invstds = 1/sqrt(var + eps) (1 x k).
    - Normalizes src into dest, then applies gamma[k]/beta[k].
    - Updates the running statistics.  The running variance is stored with the unbiasing
      factor N/(N-1), where N = num_samples*nr*nc.  averaging_factor == 1 overwrites the
      running stats; otherwise they are blended as an exponential moving average.
*/
void batch_normalize_conv (
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& means,
    resizable_tensor& invstds,
    const double averaging_factor,
    resizable_tensor& running_means,
    resizable_tensor& running_variances,
    const tensor& src,
    const tensor& gamma,
    const tensor& beta
)
{
    DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor);
    DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means));
    DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds));
    DLIB_CASSERT(
        src.num_samples() > 1 &&
        gamma.num_samples() == 1 &&
        beta.num_samples() == 1 &&
        gamma.nr() == 1 &&
        beta.nr() == 1 &&
        gamma.nc() == 1 &&
        beta.nc() == 1 &&
        gamma.k() == beta.k() &&
        beta.k() == src.k() &&
        eps > 0,
        "\ngamma.num_samples(): " << gamma.num_samples() <<
        "\ngamma.k(): " << gamma.k() <<
        "\ngamma.nr(): " << gamma.nr() <<
        "\ngamma.nc(): " << gamma.nc() <<
        "\nbeta.num_samples(): " << beta.num_samples() <<
        "\nbeta.k(): " << beta.k() <<
        "\nbeta.nr(): " << beta.nr() <<
        "\nbeta.nc(): " << beta.nc() <<
        "\nsrc.k(): " << src.k() <<
        "\nsrc.nr(): " << src.nr() <<
        "\nsrc.nc(): " << src.nc() <<
        "\neps: " << eps
    );

    dest.copy_size(src);
    means.set_size(1, src.k());
    invstds.set_size(1, src.k());
    running_means.set_size(1, src.k());
    running_variances.set_size(1, src.k());

    // first compute means and invstds
    const auto p_invstds = invstds.host();
    const auto p_means = means.host();
    const auto p_gamma = gamma.host();
    const auto p_beta = beta.host();
    auto p_src = src.host();
    auto rvar = running_variances.host();
    const long num = src.nr()*src.nc();
    // This scale makes the running variances unbiased.
    const double scale = (src.num_samples()*num)/(src.num_samples()*num-1.0);

    // Apply Welford's algorithm to improve numerical stability
    for (long k = 0; k < src.k(); ++k)
    {
        double mean = 0.0;
        double M2 = 0.0;
        long count = 0;
        for (long n = 0; n < src.num_samples(); ++n)
        {
            // Reads through a separate cursor `p`, so p_src itself still points at the
            // start of src for the normalization pass below.
            long start_index = tensor_index(src, n, k, 0, 0);
            auto p = p_src + start_index;
            for (long i = 0; i < num; ++i)
            {
                const float val = *p;
                const double delta1 = val - mean;
                mean += delta1 / (count + 1);
                const double delta2 = val - mean;
                M2 += delta1 * delta2;
                ++count;
                ++p;
            }
        }
        // Biased (population) variance; used for normalization below.
        const auto actual_var = (count > 1) ? (M2 / count) : 0.0;
        if (averaging_factor == 1)
            rvar[k] = scale*actual_var;
        else
            rvar[k] = (1-averaging_factor)*rvar[k] + scale*averaging_factor*actual_var;

        p_means[k] = mean;
        p_invstds[k] = 1.0f/std::sqrt(actual_var + eps);
    }

    // Normalize and apply the per-channel affine transform.
    auto p_dest = dest.host();
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long k = 0; k < src.k(); ++k)
        {
            for (long i = 0; i < num; ++i)
            {
                *p_dest = (*p_src - p_means[k])*p_invstds[k];
                *p_dest = (*p_dest)*p_gamma[k] + p_beta[k];
                ++p_src;
                ++p_dest;
            }
        }
    }

    // now keep track of the running means
    if (averaging_factor != 1)
        running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
    else
        running_means = means;
}

/*
    Backward pass of batch_normalize_conv.  Statistics are per channel k, taken over all
    samples and spatial positions (invnum = 1/(num_samples*nr*nc)).

    - beta_grad and gamma_grad are zeroed and then filled.
    - The gradient w.r.t. src is ADDED into src_grad (+=).
    - Assumes invstds[k] holds 1/sqrt(var[k] + eps) from the forward pass
      (TODO confirm against the caller).
*/
void batch_normalize_conv_gradient(
    const double eps,
    const tensor& gradient_input,
    const tensor& means,
    const tensor& invstds,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    tensor& beta_grad
)
{
    const long num = src.nr()*src.nc();
    DLIB_CASSERT(src.num_samples() > 1);
    DLIB_CASSERT(src.k() == (long)means.size());
    DLIB_CASSERT(src.k() == (long)invstds.size());
    DLIB_CASSERT(src.k() == (long)gamma.size());
    DLIB_CASSERT(src.k() == (long)gamma_grad.size());
    DLIB_CASSERT(src.k() == (long)beta_grad.size());
    DLIB_CASSERT(have_same_dimensions(gradient_input, src));
    DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
    DLIB_CASSERT(eps > 0);

    beta_grad = 0;
    gamma_grad = 0;

    auto p_grad = gradient_input.host();
    auto p_src = src.host();
    const auto p_gamma = gamma.host();
    const auto p_gamma_grad = gamma_grad.host();
    const auto p_beta_grad = beta_grad.host();
    const auto p_invstds = invstds.host();
    const auto p_means = means.host();

    // Per-channel scratch accumulators for dL/dvar and dL/dmean.
    resizable_tensor dvars, dmeans;
    dvars.copy_size(invstds);
    dmeans.copy_size(means);
    dvars = 0;
    dmeans = 0;
    const auto p_dvars = dvars.host();
    const auto p_dmeans = dmeans.host();

    // Pass 1: beta/gamma gradients and dL/dvar.
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long k = 0; k < src.k(); ++k)
        {
            // -0.5 * invstd^3, hoisted out of the spatial loop.
            const float invstd_pow = -0.5*std::pow(p_invstds[k], 3.0f);
            for (long i = 0; i < num; ++i)
            {
                const float x_hat = (*p_src - p_means[k])*p_invstds[k];
                p_beta_grad[k] += *p_grad;
                p_gamma_grad[k] += (*p_grad)*x_hat;

                const float dx = *p_grad * p_gamma[k];

                p_dvars[k] += dx*(*p_src - p_means[k])*invstd_pow;

                ++p_grad;
                ++p_src;
            }
        }
    }

    // Pass 2: dL/dmean.
    p_grad = gradient_input.host();
    p_src = src.host();
    const float invnum = 1.0f/(src.num_samples()*num);
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long k = 0; k < src.k(); ++k)
        {
            for (long i = 0; i < num; ++i)
            {
                const float dx = *p_grad * p_gamma[k];

                p_dmeans[k] += -dx*p_invstds[k] + p_dvars[k] * -2*(*p_src - p_means[k])*invnum;

                ++p_grad;
                ++p_src;
            }
        }
    }

    // Pass 3: dL/dsrc, accumulated into src_grad.
    p_grad = gradient_input.host();
    p_src = src.host();
    auto p_src_grad = src_grad.host();
    for (long n = 0; n < src.num_samples(); ++n)
    {
        for (long k = 0; k < src.k(); ++k)
        {
            for (long i = 0; i < num; ++i)
            {
                const float dx = *p_grad * p_gamma[k];

                *p_src_grad += dx*p_invstds[k] +
                    p_dvars[k]*2*(*p_src - p_means[k])*invnum +
                    p_dmeans[k]*invnum;


                ++p_grad;
                ++p_src;
                ++p_src_grad;
            }
        }
    }
}

// -----------------------------------------------------------------------------------
----------------------------------------------------------------------------------- void layer_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const tensor& src, const tensor& gamma, const tensor& beta ) { DLIB_CASSERT( have_same_dimensions(gamma, beta) && gamma.k() == src.k() && gamma.nr() == 1 && gamma.nc() == 1 && eps > 0, "\nsrc.k(): " << src.k() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\neps: " << eps ); dest.copy_size(src); means.set_size(src.num_samples()); invstds.set_size(src.num_samples()); // first compute means and invstds means = 0; invstds = 0; const float* p_src = src.host(); float* p_invstds = invstds.host(); float* p_means = means.host(); const long num = src.nr() * src.nc(); // compute means, and sum of squares for (long n = 0; n < src.num_samples(); ++n) { for (long k = 0; k < src.k(); ++k) { for (long i = 0; i < num; ++i) { p_means[n] += *p_src; p_invstds[n] += (*p_src) * (*p_src); ++p_src; } } } means /= src.k() * num; invstds /= src.k () * num; // copy data back to host invstds.host(); means.host(); // compute variances for (long n = 0; n < src.num_samples(); ++n) { p_invstds[n] = 1.0f / std::sqrt(p_invstds[n] - p_means[n] * p_means[n] + eps); } p_src = src.host(); float* p_dest = dest.host(); const float* p_gamma = gamma.host(); const float* p_beta = beta.host(); for (long n = 0; n < src.num_samples(); ++n) { for (long k = 0; k < src.k(); ++k) { for (long i = 0; i < num; ++i) { *p_dest = (*p_src - p_means[n]) * p_invstds[n]; *p_dest = (*p_dest) * p_gamma[k] + p_beta[k]; ++p_src; ++p_dest; } } } } void layer_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad, resizable_tensor& 
dmeans, resizable_tensor& dvars ) { const long num = src.nr() * src.nc(); DLIB_CASSERT(src.num_samples() == means.size()); DLIB_CASSERT(src.num_samples() == invstds.size()); DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad)); DLIB_CASSERT(gamma.k() == src.k()); DLIB_CASSERT(gamma.nr() == 1); DLIB_CASSERT(gamma.nc() == 1); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); DLIB_CASSERT(eps > 0); beta_grad = 0; gamma_grad = 0; auto p_grad = gradient_input.host(); auto p_src = src.host(); const auto p_gamma = gamma.host(); const auto p_gamma_grad = gamma_grad.host(); const auto p_beta_grad = beta_grad.host(); const auto p_invstds = invstds.host(); const auto p_means = means.host(); dvars.copy_size(invstds); dmeans.copy_size(means); dvars = 0; dmeans = 0; const auto p_dvars = dvars.host(); const auto p_dmeans = dmeans.host(); for (long n = 0; n < src.num_samples(); ++n) { const float invstd_pow = -0.5 * std::pow(p_invstds[n], 3.0f); for (long k = 0; k < src.k(); ++k) { for (long i = 0; i < num; ++i) { const float x_hat = (*p_src - p_means[n]) * p_invstds[n]; p_beta_grad[k] += *p_grad; p_gamma_grad[k] += (*p_grad) * x_hat; const float dx = *p_grad * p_gamma[k]; p_dvars[n] += dx * (*p_src - p_means[n]) * invstd_pow; ++p_grad; ++p_src; } } } p_grad = gradient_input.host(); p_src = src.host(); const float invnum = 1.0f / (src.k() * num); for (long n = 0; n < src.num_samples(); ++n) { for (long k = 0; k < src.k(); ++k) { for (long i = 0; i < num; ++i) { const float dx = *p_grad * p_gamma[k]; p_dmeans[n] += -dx * p_invstds[n] + p_dvars[n] * -2 * (*p_src - p_means[n]) * invnum; ++p_grad; ++p_src; } } } p_grad = gradient_input.host(); p_src = src.host(); auto p_src_grad = src_grad.host(); for (long n = 0; n < src.num_samples(); ++n) { for (long k = 0; k < src.k(); ++k) { for (long i = 0; i < num; ++i) { const float dx = *p_grad * p_gamma[k]; 
*p_src_grad += dx * p_invstds[n] + p_dvars[n] * 2 * (*p_src - p_means[n]) * invnum + p_dmeans[n] * invnum; ++p_grad; ++p_src; ++p_src_grad; } } } } // ----------------------------------------------------------------------------------- void rms_normalize( const double eps, resizable_tensor& dest, resizable_tensor& scale, const tensor& src, const tensor& gamma ) { DLIB_CASSERT( gamma.k() == src.k() && gamma.nr() == 1 && gamma.nc() == 1 && eps > 0, "\nsrc.k(): " << src.k() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\neps: " << eps ); const long ns = src.num_samples(); const long ks = src.k(); const long num = src.nr() * src.nc(); dest.copy_size(src); scale.set_size(ns); // Compute RMS values scale = 0; const float* p_src = src.host(); float* p_scale = scale.host(); for (long n = 0; n < ns; ++n) { for (long k = 0; k < ks; ++k) { for (long i = 0; i < num; ++i) { p_scale[n] += (*p_src) * (*p_src); ++p_src; } } p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast(eps)); } scale.host(); // Apply RMS normalization p_src = src.host(); float* p_dest = dest.host(); const float* p_gamma = gamma.host(); for (long n = 0; n < ns; ++n) { for (long k = 0; k < ks; ++k) { for (long i = 0; i < num; ++i) { *p_dest = (*p_src) * p_scale[n] * p_gamma[k]; ++p_src; ++p_dest; } } } } void rms_normalize_gradient( const tensor& gradient_input, const tensor& scale, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, resizable_tensor& dscale ) { DLIB_CASSERT(src.num_samples() == scale.size()); DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); DLIB_CASSERT(gamma.k() == src.k()); DLIB_CASSERT(gamma.nr() == 1); DLIB_CASSERT(gamma.nc() == 1); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); const long ns = src.num_samples(); const long ks = src.k(); const long num = src.nr() * src.nc(); gamma_grad = 0; dscale.copy_size(scale); 
dscale = 0; auto p_grad = gradient_input.host(); auto p_src = src.host(); const auto p_gamma = gamma.host(); const auto p_gamma_grad = gamma_grad.host(); const auto p_scale = scale.host(); auto p_dscale = dscale.host(); for (long n = 0; n < ns; ++n) { const float scale_pow = -0.5f * std::pow(p_scale[n], 3.0f); for (long k = 0; k < ks; ++k) { for (long i = 0; i < num; ++i) { const float x_hat = *p_src * p_scale[n]; p_gamma_grad[k] += (*p_grad) * x_hat; const float dx = *p_grad * p_gamma[k]; p_dscale[n] += dx * *p_src * scale_pow; ++p_grad; ++p_src; } } } p_grad = gradient_input.host(); p_src = src.host(); auto p_src_grad = src_grad.host(); const float invnum = 1.0f / (ks * num); for (long n = 0; n < ns; ++n) { for (long k = 0; k < ks; ++k) { for (long i = 0; i < num; ++i) { const float dx = *p_grad * p_gamma[k]; *p_src_grad += dx * p_scale[n] + p_dscale[n] * 2 * *p_src * invnum; ++p_grad; ++p_src; ++p_src_grad; } } } } // ----------------------------------------------------------------------------------- void threshold ( tensor& data, float thresh ) { const auto d = data.host(); for (size_t i = 0; i < data.size(); ++i) d[i] = d[i]>thresh ? 
1:0; } void dot ( const tensor& a, const tensor& b, tensor& result, size_t idx ) { DLIB_CASSERT(a.size() == b.size()); DLIB_CASSERT(idx < result.size()); const auto aa = a.host(); const auto bb = b.host(); auto r = result.host(); for (size_t i = 0; i < a.size(); ++i) r[idx] += aa[i]*bb[i]; } // ----------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------- namespace ttimpl { void softmax( const long num_locations, const long num_channels, tensor& dest, const tensor& src, operation_mode mode = operation_mode::CHANNEL_WISE ) { DLIB_ASSERT(num_channels * num_locations == src.nr() * src.nc() * src.k()); DLIB_CASSERT(have_same_dimensions(dest, src)); const auto d = dest.host(); const auto s = src.host(); for (long n = 0; n < src.num_samples(); ++n) { auto ss = s + num_locations * num_channels * n; auto dd = d + num_locations * num_channels * n; if (mode == operation_mode::CHANNEL_WISE) { for (long i = 0; i < num_locations; ++i) { float max_val = -std::numeric_limits::infinity(); for (long k = 0; k < num_channels; ++k) max_val = std::max(max_val, ss[k * num_locations]); float sum = 0.0f; for (long k = 0; k < num_channels; ++k) { dd[k * num_locations] = std::exp(ss[k * num_locations] - max_val); sum += dd[k * num_locations]; } for (long k = 0; k < num_channels; ++k) dd[k * num_locations] /= sum; ++ss; ++dd; } } else if (mode == operation_mode::PLANE_WISE) { for (long k = 0; k < num_channels; ++k) { auto s_channel = ss + k * num_locations; auto d_channel = dd + k * num_locations; for (long r = 0; r < src.nr(); ++r) { float max_val = -std::numeric_limits::infinity(); for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx) max_val = std::max(max_val, s_channel[idx]); if (max_val == -std::numeric_limits::infinity()) { for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, 
++idx) d_channel[idx] = 0.0f; } else { float sum = 0.0f; for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx) { d_channel[idx] = std::exp(s_channel[idx] - max_val); sum += d_channel[idx]; } for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx) d_channel[idx] /= sum; } } } } } } void softmax_gradient( const long num_locations, const long num_channels, tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode = operation_mode::CHANNEL_WISE ) { DLIB_ASSERT(num_channels * num_locations == grad.nr() * grad.nc() * grad.k()); DLIB_CASSERT(have_same_dimensions(grad, dest)); DLIB_CASSERT(have_same_dimensions(grad, gradient_input)); const auto d = dest.host(); const auto g = grad.host(); const auto in = gradient_input.host(); for (long n = 0; n < grad.num_samples(); ++n) { const auto d2 = d + num_locations * num_channels * n; const auto g2 = g + num_locations * num_channels * n; const auto in2 = in + num_locations * num_channels * n; if (mode == operation_mode::CHANNEL_WISE) { for (long i = 0; i < num_locations; ++i) { const auto d3 = d2 + i; const auto g3 = g2 + i; const auto in3 = in2 + i; float sum = 0.0f; for (long k = 0; k < num_channels; ++k) sum += -d3[k * num_locations] * in3[k * num_locations]; if (is_same_object(gradient_input, grad)) { for (long k = 0; k < num_channels; ++k) g3[k * num_locations] = d3[k * num_locations] * (sum + in3[k * num_locations]); } else { for (long k = 0; k < num_channels; ++k) g3[k * num_locations] += d3[k * num_locations] * (sum + in3[k * num_locations]); } } } else if (mode == operation_mode::PLANE_WISE) { for (long k = 0; k < num_channels; ++k) { const auto d_channel = d2 + k * num_locations; const auto g_channel = g2 + k * num_locations; const auto in_channel = in2 + k * num_locations; for (long r = 0; r < grad.nr(); ++r) { float sum = 0.0f; for (long c = 0, idx = r * grad.nc(); c < grad.nc(); ++c, ++idx) sum += -d_channel[idx] * in_channel[idx]; if (is_same_object(gradient_input, grad)) { 
for (long c = 0, idx = r * grad.nc(); c < grad.nc(); ++c, ++idx) g_channel[idx] = d_channel[idx] * (sum + in_channel[idx]); } else { for (long c = 0, idx = r * grad.nc(); c < grad.nc(); ++c, ++idx) g_channel[idx] += d_channel[idx] * (sum + in_channel[idx]); } } } } } } } // ---------------------------------------------------------------------------------------- void softmax( tensor& dest, const tensor& src, operation_mode mode ) { DLIB_CASSERT(have_same_dimensions(dest, src)); DLIB_CASSERT(mode == operation_mode::CHANNEL_WISE || mode == operation_mode::PLANE_WISE, "Invalid softmax mode"); ttimpl::softmax(src.nr() * src.nc(), src.k(), dest, src, mode); } void softmax_gradient( tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode ) { DLIB_CASSERT(have_same_dimensions(grad, dest)); DLIB_CASSERT(have_same_dimensions(grad, gradient_input)); ttimpl::softmax_gradient(grad.nr() * grad.nc(), grad.k(), grad, dest, gradient_input, mode); } // ------------------------------------------------------------------------------------ void softmax_all ( tensor& dest, const tensor& src ) { DLIB_CASSERT(have_same_dimensions(dest,src)); ttimpl::softmax(1, src.nr()*src.nc()*src.k(), dest, src); } void softmax_all_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { DLIB_CASSERT(have_same_dimensions(grad,dest)); DLIB_CASSERT(have_same_dimensions(grad,gradient_input)); ttimpl::softmax_gradient(1, grad.nr()*grad.nc()*grad.k(), grad, dest, gradient_input); } // ------------------------------------------------------------------------------------ void sigmoid ( tensor& dest, const tensor& src ) { const auto d = dest.host(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) d[i] = 1/(1+std::exp(-s[i])); } void sigmoid_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { const auto g = grad.host(); const auto d = dest.host(); const auto in = gradient_input.host(); if (is_same_object(gradient_input, 
grad)) { for (size_t i = 0; i < dest.size(); ++i) g[i] = in[i]*d[i]*(1-d[i]); } else { for (size_t i = 0; i < dest.size(); ++i) g[i] += in[i]*d[i]*(1-d[i]); } } // ------------------------------------------------------------------------------------ void mish ( tensor& dest, const tensor& src ) { const auto d = dest.host_write_only(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) { const auto e = std::exp(s[i]); const auto delta = 2*e + e*e + 2; d[i] = s[i] - 2*s[i]/delta; } } void mish_gradient( tensor& grad, const tensor& src, const tensor& gradient_input ) { const auto g = grad.host(); const auto s = src.host(); const auto in = gradient_input.host(); const auto calculate_gradient = [](float x) { if (x >= 8) return 1.f; if (x <= -8) return 0.f; const auto e = std::exp(x); const auto delta = 2*e + e*e + 2; const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6); return e*omega/(delta*delta); }; if (is_same_object(gradient_input, grad)) { for (size_t i = 0; i < src.size(); ++i) g[i] = in[i]*calculate_gradient(s[i]); } else { for (size_t i = 0; i < src.size(); ++i) g[i] += in[i]*calculate_gradient(s[i]); } } // ------------------------------------------------------------------------------------ void relu ( tensor& dest, const tensor& src ) { dest = lowerbound(mat(src), 0); } void relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { const float* gi = gradient_input.host(); const float* in = dest.host(); float* out = grad.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0) out[i] = gi[i]; else out[i] = 0; } } else { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0) out[i] += gi[i]; } } } // ---------------------------------------------------------------------------------------- void prelu ( tensor& dest, const tensor& src, const tensor& param ) { const float p = param.host()[0]; const float* s = src.host(); float* d = dest.host(); for (size_t i = 0; i 
< dest.size(); ++i) { if (s[i] > 0) d[i] = s[i]; else d[i] = p*s[i]; } } void prelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input, const tensor& param, tensor& params_grad ) { DLIB_CASSERT(is_same_object(grad, gradient_input) == false); const float p = param.host()[0]; const float* gi = gradient_input.host(); const float* s = src.host(); float* out = grad.host(); float pgrad = 0; for (size_t i = 0; i < src.size(); ++i) { if (s[i] > 0) { out[i] += gi[i]; } else { out[i] += p*gi[i]; pgrad += gi[i]*s[i]; } } params_grad.host()[0] = pgrad; } // ------------------------------------------------------------------------------------ void leaky_relu ( tensor& dest, const tensor& src, const float alpha ) { const float* s = src.host(); float* d = dest.host(); for (size_t i = 0; i < dest.size(); ++i) { if (s[i] > 0) d[i] = s[i]; else d[i] = alpha * s[i]; } } void leaky_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ) { const float* gi = gradient_input.host(); const float* in = dest.host(); float* out = grad.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0) out[i] = gi[i]; else out[i] = alpha * gi[i]; } } else { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0) out[i] += gi[i]; else out[i] += alpha * gi[i]; } } } // ------------------------------------------------------------------------------------ void tanh ( tensor& dest, const tensor& src ) { const auto d = dest.host(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) d[i] = std::tanh(s[i]); } void tanh_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { const auto g = grad.host(); const auto d = dest.host(); const auto in = gradient_input.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < dest.size(); ++i) g[i] = in[i]*(1-d[i]*d[i]); } else { for (size_t i = 0; i < dest.size(); ++i) g[i] += in[i]*(1-d[i]*d[i]); } 
} // ---------------------------------------------------------------------------------------- void clipped_relu ( tensor& dest, const tensor& src, const float ceiling ) { dest = upperbound(lowerbound(mat(src), 0), ceiling); } void clipped_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float ceiling ) { const auto out = grad.host(); const auto in = dest.host(); const auto gi = gradient_input.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0 && in[i] < ceiling) out[i] = gi[i]; else out[i] = 0; } } else { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0 && in[i] < ceiling) out[i] += gi[i]; } } } // ---------------------------------------------------------------------------------------- void elu ( tensor& dest, const tensor& src, const float alpha ) { const auto d = dest.host(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) { if (s[i] > 0) d[i] = s[i]; else d[i] = alpha * (std::exp(s[i]) - 1.0f); } } void elu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ) { const auto out = grad.host(); const auto in = dest.host(); const auto gi = gradient_input.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0) out[i] = gi[i]; else out[i] = (alpha + in[i]) * gi[i]; } } else { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] > 0) out[i] += gi[i]; else out[i] += (alpha + in[i]) * gi[i]; } } } // ---------------------------------------------------------------------------------------- void gelu ( tensor& dest, const tensor& src ) { const auto d = dest.host(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) d[i] = 0.5f*s[i]*(1.0f + std::erf(s[i]/sqrt_2)); } void gelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ) { const float beta = 1.0f / std::sqrt(2.0f * pi); const auto compute_gradient = [beta](float x) { 
const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2)); const float pdf = beta*std::exp(-0.5f*x*x); return cdf + x * pdf; }; const auto g = grad.host(); const auto s = src.host(); const auto in = gradient_input.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < src.size(); ++i) g[i] = in[i]*compute_gradient(s[i]); } else { for (size_t i = 0; i < src.size(); ++i) g[i] += in[i]*compute_gradient(s[i]); } } // ---------------------------------------------------------------------------------------- void smelu ( tensor& dest, const tensor& src, const float beta ) { const float* s = src.host(); float* d = dest.host(); for (size_t i = 0; i < dest.size(); ++i) { if (s[i] >= beta) d[i] = s[i]; else if (s[i] <= -beta) d[i] = 0; else d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta); } } void smelu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float beta ) { const float* gi = gradient_input.host(); const float* in = dest.host(); float* out = grad.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] >= beta) out[i] = gi[i]; else if (in[i] == 0) out[i] = 0; else out[i] = std::sqrt(beta * in[i]) / beta * gi[i]; } } else { for (size_t i = 0; i < dest.size(); ++i) { if (in[i] >= beta) out[i] += gi[i]; else if (in[i] == 0) continue; else out[i] += std::sqrt(beta * in[i]) / beta * gi[i]; } } } // ---------------------------------------------------------------------------------------- void silu ( tensor& dest, const tensor& src ) { const auto d = dest.host(); const auto s = src.host(); for (size_t i = 0; i < src.size(); ++i) d[i] = s[i] * impl::sigmoid(s[i]); } void silu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ) { const auto g = grad.host(); const auto s = src.host(); const auto in = gradient_input.host(); if (is_same_object(grad, gradient_input)) { for (size_t i = 0; i < src.size(); ++i) { const auto sig_s = impl::sigmoid(s[i]); g[i] = in[i] * 
(sig_s * (1.0f + s[i] * (1.0f - sig_s))); } } else { for (size_t i = 0; i < src.size(); ++i) { const auto sig_s = impl::sigmoid(s[i]); g[i] += in[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s))); } } } // ---------------------------------------------------------------------------------------- void resize_bilinear ( tensor& dest, long long dest_row_stride, long long dest_channel_stride, const tensor& src, long long src_row_stride, long long src_channel_stride ) { DLIB_CASSERT(is_same_object(dest, src)==false); DLIB_CASSERT(dest.num_samples() == src.num_samples()); DLIB_CASSERT(dest.k() == src.k()); if (dest.size() == 0 || src.size() == 0) return; const float* s = src.host(); float* d = dest.host(); parallel_for(0, dest.k()*dest.num_samples(), [&](long i) { auto simg = sub_image(s+i*src_channel_stride, src.nr(), src.nc(), src_row_stride); auto dimg = sub_image(d+i*dest_channel_stride, dest.nr(), dest.nc(), dest_row_stride); resize_image(simg, dimg); }); } void resize_bilinear_gradient ( tensor& grad, long long grad_row_stride, long long grad_channel_stride, const tensor& gradient_input, long long gradient_input_row_stride, long long gradient_input_channel_stride ) { DLIB_CASSERT(is_same_object(grad, gradient_input)==false); DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples()); DLIB_CASSERT(gradient_input.k() == grad.k()); if (gradient_input.size() == 0 || grad.size() == 0) return; const float* gi = gradient_input.host(); float* g = grad.host(); const float x_scale = (grad.nc()-1)/(float)std::max((gradient_input.nc()-1),1); const float y_scale = (grad.nr()-1)/(float)std::max((gradient_input.nr()-1),1); for (long long samp = 0; samp < gradient_input.num_samples(); ++samp) { for (long long k = 0; k < gradient_input.k(); ++k) { for (long long r = 0; r < gradient_input.nr(); ++r) { const float y = r*y_scale; const long long top = static_cast(std::floor(y)); const long long bottom = std::min(top+1, grad.nr()-1); const float tb_frac = y - top; for (long long c = 0; 
c < gradient_input.nc(); ++c) { const float x = c*x_scale; const long long left = static_cast(std::floor(x)); const long long right = std::min(left+1, grad.nc()-1); const float lr_frac = x - left; const float tmp = gi[r*gradient_input_row_stride+c]; g[top*grad_row_stride+left] += tmp*(1-tb_frac)*(1-lr_frac); g[top*grad_row_stride+right] += tmp*(1-tb_frac)*(lr_frac); g[bottom*grad_row_stride+left] += tmp*(tb_frac)*(1-lr_frac); g[bottom*grad_row_stride+right] += tmp*(tb_frac)*(lr_frac); } } g += grad_channel_stride; gi += gradient_input_channel_stride; } } } // ---------------------------------------------------------------------------------------- void reorg( bool add_to, tensor& dest, const int row_stride, const int col_stride, const tensor& src ) { DLIB_CASSERT(!is_same_object(dest, src), "Destination and source must be distinct objects."); DLIB_CASSERT(src.nr() % row_stride == 0, "The number of rows in src must be divisible by row_stride."); DLIB_CASSERT(src.nc() % col_stride == 0, "The number of columns in src must be divisible by col_stride."); DLIB_CASSERT(dest.num_samples() == src.num_samples(), "The number of samples must match."); DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride, "The number of channels must match."); DLIB_CASSERT(dest.nr() == src.nr() / row_stride, "The number of rows must match."); DLIB_CASSERT(dest.nc() == src.nc() / col_stride, "The number of columns must match."); const float* s = src.host(); float* d = dest.host(); const size_t sk = src.k(), snr = src.nr(), snc = src.nc(); const size_t dk = dest.k(), dnr = dest.nr(), dnc = dest.nc(), dsize = dest.size(); dlib::parallel_for(0, dsize, [&](long i) { const size_t out_plane_size = dnr * dnc; const size_t out_sample_size = dk * out_plane_size; const size_t n = i / out_sample_size; const size_t out_idx = i % out_sample_size; const size_t out_k = out_idx / out_plane_size; const size_t out_rc = out_idx % out_plane_size; const size_t out_r = out_rc / dnc; const size_t out_c = out_rc % 
dnc; const size_t in_k = out_k % sk; const size_t in_r = out_r * row_stride + (out_k / sk) / col_stride; const size_t in_c = out_c * col_stride + (out_k / sk) % col_stride; const size_t in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c; if (add_to) d[i] += s[in_idx]; else d[i] = s[in_idx]; }); } void reorg_gradient( bool add_to, tensor& grad, const int row_stride, const int col_stride, const tensor& gradient_input ) { DLIB_CASSERT(!is_same_object(grad, gradient_input), "Grad and gradient_input must be distinct objects."); DLIB_CASSERT(grad.nr() % row_stride == 0, "The number of rows in grad must be divisible by row_stride."); DLIB_CASSERT(grad.nc() % col_stride == 0, "The number of columns in grad must be divisible by col_stride."); DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples(), "The number of samples in grad and gradient_input must match."); DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride, "The number of channels in grad must be gradient_input.k() divided by row_stride and col_stride."); DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride, "The number of rows in grad must be gradient_input.nr() multiplied by row_stride."); DLIB_CASSERT(grad.nc() == gradient_input.nc() * col_stride, "The number of columns in grad must be gradient_input.nc() multiplied by col_stride."); const float* gi = gradient_input.host(); float* g = grad.host(); parallel_for(0, gradient_input.num_samples(), [&](long n) { for (long k = 0; k < gradient_input.k(); ++k) { for (long r = 0; r < gradient_input.nr(); ++r) { for (long c = 0; c < gradient_input.nc(); ++c) { const auto in_idx = tensor_index(gradient_input, n, k, r, c); const auto out_idx = tensor_index(grad, n, k % grad.k(), r * row_stride + (k / grad.k()) / col_stride, c * col_stride + (k / grad.k()) % col_stride); if (add_to) g[out_idx] += gi[in_idx]; else g[out_idx] = gi[in_idx]; } } } }); } // ------------------------------------------------------------------------------------ void 
embeddings( resizable_tensor& dest, const tensor& src, const tensor& embs ) { DLIB_CASSERT( src.nr() > 0 && embs.num_samples() > 0 && embs.k() > 0 && embs.nr() == 1 && embs.nc() == 1, "\nsrc.num_samples(): " << src.num_samples() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\nembs.num_samples(): " << embs.num_samples() << "\nembs.k(): " << embs.k() << "\nembs.nr(): " << embs.nr() << "\nembs.nc(): " << embs.nc() ); long ns = dest.num_samples(), nk = dest.k(), nr = dest.nr(), nc = dest.nc(); const float* src_data = src.host(); float* dest_data = dest.host(); const float* embs_data = embs.host(); for (long s = 0; s < ns; ++s) { for (long k = 0; k < nk; ++k) { for (long r = 0; r < nr; ++r) { const unsigned long token_idx = static_cast(src_data[tensor_index(src, s, k, r, 0)]); if (token_idx < embs.num_samples()) { for (long c = 0; c < nc; ++c) dest_data[tensor_index(dest, s, k, r, c)] = embs_data[tensor_index(embs, token_idx, c, 0, 0)]; } else { for (long c = 0; c < nc; ++c) dest_data[tensor_index(dest, s, k, r, c)] = 0; } } } } } void embeddings_gradient( const tensor& prev, const tensor& gradient_input, tensor& grads, const tensor& freqs, float learning_rate, bool scale ) { DLIB_CASSERT( prev.nr() > 0 && gradient_input.num_samples() == prev.num_samples() && gradient_input.k() == prev.k() && gradient_input.nr() == prev.nr() && gradient_input.nc() == grads.k() && grads.num_samples() > 0 && grads.k() > 0 && grads.nr() == 1 && grads.nc() == 1, "\ngradient_input.num_samples(): " << gradient_input.num_samples() << "\ngradient_input.k(): " << gradient_input.k() << "\ngradient_input.nr(): " << gradient_input.nr() << "\ngradient_input.nc(): " << gradient_input.nc() << "\nprev.num_samples(): " << prev.num_samples() << "\nprev.k(): " << prev.k() << "\nprev.nr(): " << prev.nr() << "\nprev.nc(): " << prev.nc() << "\ngrads.num_samples(): " << grads.num_samples() << "\ngrads.k(): " << grads.k() << "\ngrads.nr(): " << grads.nr() << 
"\ngrads.nc(): " << grads.nc() ); const float* prev_data = prev.host(); const float* gradient_input_data = gradient_input.host(); const float* freqs_data = freqs.host(); float* grads_data = grads.host(); long ns = gradient_input.num_samples(), nk = gradient_input.k(); long nr = gradient_input.nr(), nc = gradient_input.nc(); std::vector embedding_mutexes(grads.num_samples()); parallel_for(0, ns * nk, [&](long i) { long s = i / nk; long k = i % nk; for (long r = 0; r < nr; ++r) { const unsigned long token_idx = static_cast(prev_data[tensor_index(prev, s, k, r, 0)]); if (token_idx < grads.num_samples()) { const float freg_token = freqs_data[token_idx]; float freq_scale = 1.0f; if (scale && freg_token != 0.0f) freq_scale = std::min(0.15f, std::max(1.0f / freg_token, 1.0f)); auto_mutex locker(embedding_mutexes[token_idx]); for (long c = 0; c < nc; ++c) { const float gradient = gradient_input_data[tensor_index(gradient_input, s, k, r, c)]; grads_data[tensor_index(grads, token_idx, c, 0, 0)] -= (gradient * learning_rate * freq_scale); } } } }); } // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ pooling::pooling ( ) : window_height(0),window_width(0),stride_y(0),stride_x(0),padding_y(0),padding_x(0),do_max_pooling(true) { } void pooling:: clear( ) { window_height = 0; window_width = 0; stride_y = 0; stride_x = 0; padding_y = 0; padding_x = 0; } void pooling:: setup_max_pooling( int window_height_, int window_width_, int stride_y_, int stride_x_, int padding_y_, int padding_x_ ) { DLIB_CASSERT(window_width_ > 0); DLIB_CASSERT(window_height_ > 0); DLIB_CASSERT(stride_y_ > 0); DLIB_CASSERT(stride_x_ > 0); DLIB_CASSERT(0 <= padding_y_ && padding_y_ < window_height_); DLIB_CASSERT(0 <= padding_x_ && padding_x_ < window_width_); window_height = window_height_; window_width = window_width_; stride_y = stride_y_; stride_x = stride_x_; padding_y = 
padding_y_; padding_x = padding_x_; do_max_pooling = true; } void pooling:: setup_avg_pooling( int window_height_, int window_width_, int stride_y_, int stride_x_, int padding_y_, int padding_x_ ) { DLIB_CASSERT(window_width_ > 0); DLIB_CASSERT(window_height_ > 0); DLIB_CASSERT(stride_y_ > 0); DLIB_CASSERT(stride_x_ > 0); DLIB_CASSERT(0 <= padding_y_ && padding_y_ < window_height_); DLIB_CASSERT(0 <= padding_x_ && padding_x_ < window_width_); window_height = window_height_; window_width = window_width_; stride_y = stride_y_; stride_x = stride_x_; padding_y = padding_y_; padding_x = padding_x_; do_max_pooling = false; } void pooling:: operator() ( resizable_tensor& dest, const tensor& src ) { DLIB_CASSERT(window_width > 0); DLIB_CASSERT(window_height > 0); DLIB_CASSERT(stride_y > 0); DLIB_CASSERT(stride_x > 0); DLIB_CASSERT(0 <= padding_y && padding_y < window_height); DLIB_CASSERT(0 <= padding_x && padding_x < window_width); DLIB_CASSERT(window_width <= src.nc() + 2*padding_x, "Pooling windows must be small enough to fit into the padded image."); DLIB_CASSERT(window_height <= src.nr() + 2*padding_y, "Pooling windows must be small enough to fit into the padded image."); dest.set_size( src.num_samples(), src.k(), 1+(src.nr()+2*padding_y-window_height)/stride_y, 1+(src.nc()+2*padding_x-window_width)/stride_x ); if (src.size() == 0) { dest = 0; return; } auto d = dest.host(); const long x_offset = window_width/2 - padding_x; const long y_offset = window_height/2 - padding_y; if (does_max_pooling()) { for (long n = 0; n < dest.num_samples(); ++n) { for (long k = 0; k < dest.k(); ++k) { auto simg = image_plane(src,n,k); auto dimg = d + (n*dest.k() + k)*dest.nr()*dest.nc(); for (long r = 0; r < dest.nr(); ++r) { for (long c = 0; c < dest.nc(); ++c) { auto win = centered_rect(c*stride_x+x_offset, r*stride_y+y_offset, window_width, window_height); dimg[r*dest.nc() + c] = max(subm_clipped(simg,win)); } } } } } else { for (long n = 0; n < dest.num_samples(); ++n) { for (long 
k = 0; k < dest.k(); ++k) { auto simg = image_plane(src,n,k); auto dimg = d + (n*dest.k() + k)*dest.nr()*dest.nc(); for (long r = 0; r < dest.nr(); ++r) { for (long c = 0; c < dest.nc(); ++c) { auto win = centered_rect(c*stride_x+x_offset, r*stride_y+y_offset, window_width, window_height); dimg[r*dest.nc() + c] = mean(subm_clipped(simg,win)); } } } } } } void pooling::get_gradient( const tensor& gradient_input, const tensor& dest, const tensor& src, tensor& grad ) { DLIB_CASSERT(have_same_dimensions(gradient_input,dest)); DLIB_CASSERT(have_same_dimensions(src,grad)); if (src.size() == 0) { return; } auto gi = gradient_input.host(); auto g = grad.host(); const long x_offset = window_width/2 - padding_x; const long y_offset = window_height/2 - padding_y; if (does_max_pooling()) { for (long n = 0; n < dest.num_samples(); ++n) { for (long k = 0; k < dest.k(); ++k) { auto simg = image_plane(src,n,k); auto gimg = g + (n*grad.k() + k)*grad.nr()*grad.nc(); auto giimg = gi + (n*dest.k() + k)*dest.nr()*dest.nc(); auto imgbox = get_rect(simg); for (long r = 0; r < dest.nr(); ++r) { for (long c = 0; c < dest.nc(); ++c) { auto win = centered_rect(c*stride_x+x_offset, r*stride_y+y_offset, window_width, window_height).intersect(imgbox); auto p = max_point(subm(simg,win))+win.tl_corner(); gimg[p.y()*grad.nc()+p.x()] += giimg[r*dest.nc()+c]; } } } } } else { for (long n = 0; n < dest.num_samples(); ++n) { for (long k = 0; k < dest.k(); ++k) { auto simg = image_plane(src,n,k); auto gimg = g + (n*grad.k() + k)*grad.nr()*grad.nc(); auto giimg = gi + (n*dest.k() + k)*dest.nr()*dest.nc(); auto imgbox = get_rect(simg); for (long r = 0; r < dest.nr(); ++r) { for (long c = 0; c < dest.nc(); ++c) { auto win = centered_rect(c*stride_x+x_offset, r*stride_y+y_offset, window_width, window_height).intersect(imgbox); const float delta = giimg[r*dest.nc()+c]/win.area(); for (long y = win.top(); y <= win.bottom(); ++y) { for (long x = win.left(); x <= win.right(); ++x) { gimg[y*grad.nc()+x] += 
delta; } } } } } } } } // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ void img2col( matrix& output, const tensor& data, long n, long filter_nr, long filter_nc, long stride_y, long stride_x, long padding_y, long padding_x ) { const auto d = data.host() + data.k()*data.nr()*data.nc()*n; const rectangle boundary = get_rect(data); const long out_nr = 1+(data.nr()+2*padding_y-filter_nr)/stride_y; const long out_nc = 1+(data.nc()+2*padding_x-filter_nc)/stride_x; output.set_size(out_nr*out_nc, data.k()*filter_nr*filter_nc); DLIB_CASSERT(output.size() != 0); float* t = &output(0,0); // now fill in the Toeplitz output matrix for the n-th sample in data. long cnt = 0; const long max_r = data.nr() + padding_y-(filter_nr-1); const long max_c = data.nc() + padding_x-(filter_nc-1); for (long r = -padding_y; r < max_r; r+=stride_y) { for (long c = -padding_x; c < max_c; c+=stride_x) { for (long k = 0; k < data.k(); ++k) { for (long y = 0; y < filter_nr; ++y) { for (long x = 0; x < filter_nc; ++x) { DLIB_ASSERT(cnt < output.size()); long xx = c+x; long yy = r+y; if (boundary.contains(xx,yy)) *t = d[(k*data.nr() + yy)*data.nc() + xx]; else *t = 0; ++t; ++cnt; } } } } } } void col2img( const matrix& output, tensor& data, long n, long filter_nr, long filter_nc, long stride_y, long stride_x, long padding_y, long padding_x ) { const auto d = data.host() + data.k()*data.nr()*data.nc()*n; const rectangle boundary = get_rect(data); DLIB_CASSERT(output.size() != 0); const float* t = &output(0,0); // now fill in the Toeplitz output matrix for the n-th sample in data. 
const long max_r = data.nr() + padding_y-(filter_nr-1); const long max_c = data.nc() + padding_x-(filter_nc-1); for (long r = -padding_y; r < max_r; r+=stride_y) { for (long c = -padding_x; c < max_c; c+=stride_x) { for (long k = 0; k < data.k(); ++k) { for (long y = 0; y < filter_nr; ++y) { for (long x = 0; x < filter_nc; ++x) { long xx = c+x; long yy = r+y; if (boundary.contains(xx,yy)) d[(k*data.nr() + yy)*data.nc() + xx] += *t; ++t; } } } } } } void tensor_conv::operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters ) { DLIB_CASSERT(last_stride_y > 0 && last_stride_x > 0, "You must call setup() before calling this function."); output.set_size(data.num_samples(), filters.num_samples(), 1+(data.nr()+2*last_padding_y-filters.nr())/last_stride_y, 1+(data.nc()+2*last_padding_x-filters.nc())/last_stride_x); (*this)(add_to_output, static_cast(output),data,filters); } void tensor_conv::operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters ) { DLIB_CASSERT(is_same_object(output,data) == false); DLIB_CASSERT(is_same_object(output,filters) == false); DLIB_CASSERT(filters.k() == data.k()); DLIB_CASSERT(last_stride_y > 0 && last_stride_x > 0, "You must call setup() before calling this function."); DLIB_CASSERT(filters.nr() <= data.nr() + 2*last_padding_y, "Filter windows must be small enough to fit into the padded image."); DLIB_CASSERT(filters.nc() <= data.nc() + 2*last_padding_x, "Filter windows must be small enough to fit into the padded image."); DLIB_CASSERT(output.num_samples() == data.num_samples()); DLIB_CASSERT(output.k() == filters.num_samples()); DLIB_CASSERT(output.nr() == 1+(data.nr()+2*last_padding_y-filters.nr())/last_stride_y); DLIB_CASSERT(output.nc() == 1+(data.nc()+2*last_padding_x-filters.nc())/last_stride_x); matrix temp; for (long n = 0; n < data.num_samples(); ++n) { img2col(temp, data, n, filters.nr(), filters.nc(), last_stride_y, last_stride_x, 
last_padding_y, last_padding_x); if (add_to_output) output.add_to_sample(n, mat(filters)*trans(temp)); else output.set_sample(n, mat(filters)*trans(temp)); } } void tensor_conv::operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ) { DLIB_CASSERT(filters.num_samples() == biases.k()); (*this)(add_to_output, output,data,filters); tt::add(1, output, 1, biases); if (use_relu) tt::relu(output, output); } void tensor_conv::operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ) { DLIB_CASSERT(filters.num_samples() == biases.k()); (*this)(add_to_output, output, data, filters); tt::add(1, output, 1, biases); if (use_relu) tt::relu(output, output); } // ------------------------------------------------------------------------------------ void tensor_conv:: get_gradient_for_data ( const bool add_to_output, const tensor& gradient_input, const tensor& filters, tensor& data_gradient ) { matrix temp; if (!add_to_output) data_gradient = 0; for (long n = 0; n < gradient_input.num_samples(); ++n) { auto gi = mat(gradient_input.host()+gradient_input.k()*gradient_input.nr()*gradient_input.nc()*n, gradient_input.k(), gradient_input.nr()*gradient_input.nc()); temp = trans(gi)*mat(filters); col2img(temp, data_gradient, n, filters.nr(), filters.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x); } } // ------------------------------------------------------------------------------------ void tensor_conv:: get_gradient_for_filters ( const bool add_to_output, const tensor& gradient_input, const tensor& data, tensor& filters_gradient ) { matrix temp; for (long n = 0; n < gradient_input.num_samples(); ++n) { auto gi = mat(gradient_input.host()+gradient_input.k()*gradient_input.nr()*gradient_input.nc()*n, gradient_input.k(), gradient_input.nr()*gradient_input.nc()); img2col(temp, data, n, 
filters_gradient.nr(), filters_gradient.nc(), last_stride_y, last_stride_x, last_padding_y, last_padding_x); if (n == 0) { if (add_to_output) filters_gradient += gi*temp; else filters_gradient = gi*temp; } else { filters_gradient += gi*temp; } } } // ------------------------------------------------------------------------------------ void copy_tensor( bool add_to, tensor& dest, size_t dest_k_offset, const tensor& src, size_t src_k_offset, size_t count_k ) { const size_t dest_sample_size = static_cast(dest.nc() * dest.nr() * dest.k()); const size_t src_sample_size = static_cast(src.nc() * src.nr() * src.k()); const size_t block_size = count_k * dest.nc() * dest.nr(); DLIB_CASSERT(dest.num_samples() == src.num_samples() && dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size"); DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor"); DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor"); float* dest_p = dest.host() + dest_k_offset * dest.nc() * dest.nr(); const float* src_p = src.host() + src_k_offset * src.nc() * src.nr(); for (long i = 0; i < src.num_samples(); ++i) { if (add_to) { for (size_t j = 0; j < block_size; ++j) dest_p[j] += src_p[j]; } else { ::memcpy(dest_p, src_p, block_size * sizeof(float)); } dest_p += dest_sample_size; src_p += src_sample_size; } } // ------------------------------------------------------------------------------------ void copy_tensor( bool add_to, tensor& dest, size_t dk, size_t dnr, size_t dnc, const tensor& src, size_t sk, size_t snr, size_t snc, size_t k, size_t nr, size_t nc ) { size_t dest_stride_sample = static_cast(dest.nc() * dest.nr() * dest.k()); size_t dest_stride_k = static_cast(dest.nc() * dest.nr()); size_t dest_stride_nr = static_cast(dest.nc()); size_t src_stride_sample = static_cast(src.nc() * src.nr() * src.k()); size_t src_stride_k = static_cast(src.nc() * src.nr()); size_t src_stride_nr = static_cast(src.nc()); 
DLIB_CASSERT(dest.num_samples() == src.num_samples(), "All sources should fit into dest tensor size"); DLIB_CASSERT(dest.k() - dk >= k && dest.nr() - dnr >= nr && dest.nc() - dnc >= nc, "Not enough space in dest tensor"); DLIB_CASSERT(src.k() - sk >= k && src.nr() - snr >= nr && src.nc() - snc >= nc, "Not enough space in src tensor"); float* dest_p = dest.host() + dk * dest_stride_k \ + dnr * dest_stride_nr \ + dnc; const float* src_p = src.host() + sk * src_stride_k \ + snr * src_stride_nr \ + snc; for (long i = 0; i < src.num_samples(); ++i) { float* dest_channel_p = dest_p; const float* src_channel_p = src_p; for (long j = 0; j < k; ++j) { float* dest_row_p = dest_channel_p; const float* src_row_p = src_channel_p; for (long r = 0; r < nr; ++r) { if (add_to) { for (size_t c = 0; c < nc; ++c) dest_row_p[c] += src_row_p[c]; } else { ::memcpy(dest_row_p, src_row_p, nc * sizeof(float)); } dest_row_p += dest_stride_nr; src_row_p += src_stride_nr; } dest_channel_p += dest_stride_k; src_channel_p += src_stride_k; } dest_p += dest_stride_sample; src_p += src_stride_sample; } } // ------------------------------------------------------------------------------------ void transpose( bool add, tensor& dest, const tensor& src ) { DLIB_CASSERT(dest.num_samples() == src.num_samples() && dest.k() == src.k() && dest.nr() == src.nc() && dest.nc() == src.nr(), "Incompatible tensor dimensions."); const float* src_data = src.host(); float* dest_data = dest.host(); const long num_samples = src.num_samples(); const long k_dim = src.k(); const long src_nr = src.nr(); const long src_nc = src.nc(); const long dest_nr = dest.nr(); const long dest_nc = dest.nc(); parallel_for(0, num_samples * k_dim, [&](long i) { const long n = i / k_dim; const long k = i % k_dim; const long src_nk_offset = (n * src.k() + k) * src_nr; const long dest_nk_offset = (n * dest.k() + k) * dest_nr; for (long r = 0; r < src_nr; ++r) { for (long c = 0; c < src_nc; ++c) { const long src_idx = (src_nk_offset + r) * 
src_nc + c; const long dest_idx = (dest_nk_offset + c) * dest_nc + r; if (add) dest_data[dest_idx] += src_data[src_idx]; else dest_data[dest_idx] = src_data[src_idx]; } } }); } // ------------------------------------------------------------------------------------ void compute_act_halt_probabilities( resizable_tensor& halt_probs, resizable_tensor& logits, const tensor& input_data, const tensor& halt_params, long batch_size, long seq_len, long feature_dim ) { const float* in_ptr = input_data.host(); const float* W_halt = halt_params.host(); const float b_halt = halt_params.host()[feature_dim]; float* logits_ptr = logits.host(); float* halt_probs_ptr = halt_probs.host(); const long d_model = feature_dim / input_data.k(); const long num_channels = input_data.k(); for (long pos = 0; pos < batch_size * seq_len; ++pos) { const long n = pos / seq_len; const long s = pos % seq_len; float logit = b_halt; for (long c = 0; c < num_channels; ++c) { for (long d = 0; d < d_model; ++d) { const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d; const long weight_idx = c * d_model + d; logit += in_ptr[in_idx] * W_halt[weight_idx]; } } logits_ptr[pos] = logit; halt_probs_ptr[pos] = 1.0f / (1.0f + std::exp(-logit)); } } void update_act_state( resizable_tensor& output, const tensor& input_data, const tensor& halt_probs, resizable_tensor& cumulative_halting, resizable_tensor& remainders, resizable_tensor& n_steps, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels, float halt_threshold, long current_step ) { const float* in_ptr = input_data.host(); const float* p_halt = halt_probs.host(); float* out_ptr = output.host(); float* cum_halt = cumulative_halting.host(); float* remain = remainders.host(); float* steps = n_steps.host(); float* eff_weights = effective_weights.host(); for (long pos = 0; pos < batch_size * seq_len; ++pos) { if (cum_halt[pos] < halt_threshold) { const long n = pos / seq_len; const long s = pos % 
seq_len; float p = p_halt[pos]; float r = remain[pos]; float effective = std::min(p * r, halt_threshold - cum_halt[pos]); cum_halt[pos] += effective; remain[pos] -= effective; steps[pos] = static_cast(current_step + 1); eff_weights[pos] += effective; for (long c = 0; c < num_channels; ++c) { for (long d = 0; d < d_model; ++d) { const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d; out_ptr[idx] += effective * in_ptr[idx]; } } } } } void finalize_act_output( resizable_tensor& output, const tensor& input_data, const tensor& remainders, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels ) { const float* in_ptr = input_data.host(); const float* remain = remainders.host(); float* out_ptr = output.host(); float* eff_weights = effective_weights.host(); for (long pos = 0; pos < batch_size * seq_len; ++pos) { float r = remain[pos]; if (r > 1e-6f) { const long n = pos / seq_len; const long s = pos % seq_len; eff_weights[pos] += r; for (long c = 0; c < num_channels; ++c) { for (long d = 0; d < d_model; ++d) { const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d; out_ptr[idx] += r * in_ptr[idx]; } } } } } void apply_act_depth_scaling( tensor& gradients, const tensor& n_steps, long batch_size, long seq_len, long d_model, long num_channels, float max_steps, float scale_factor ) { const float* steps = n_steps.host(); float* grad_ptr = gradients.host(); for (long pos = 0; pos < batch_size * seq_len; ++pos) { const float scale = 1.0f + scale_factor * (steps[pos] / max_steps); const long n = pos / seq_len; const long s = pos % seq_len; for (long c = 0; c < num_channels; ++c) { for (long d = 0; d < d_model; ++d) { const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d; grad_ptr[idx] *= scale; } } } } // ------------------------------------------------------------------------------------ } } #endif // DLIB_DNN_CPU_cPP_ ================================================ FILE: 
dlib/cuda/cpu_dlib.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CPU_H_ #define DLIB_DNN_CPU_H_ // This file contains CPU implementations of the GPU based functions in cuda_dlib.h // and cudnn_dlibapi.h #include "tensor.h" #include "../geometry/rectangle.h" #include "../dnn/utilities.h" namespace dlib { namespace cpu { // ----------------------------------------------------------------------------------- void multiply ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ); void multiply_conv ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ); void multiply_zero_padded ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ); void scale_channels ( bool add_to, tensor& dest, const tensor& src, const tensor& scales ); void add( float beta, tensor& dest, float alpha, const tensor& src ); void assign_bias_gradient ( tensor& grad, const tensor& gradient_input ); void add ( tensor& dest, const tensor& src1, const tensor& src2 ); void assign_conv_bias_gradient ( tensor& grad, const tensor& gradient_input ); // ----------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const float A, const float B ); void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B, const float C ); void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C, const float D ); void affine_transform_range( size_t begin, size_t end, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ); // ----------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const 
tensor& A, const tensor& B ); // ----------------------------------------------------------------------------------- void affine_transform_conv( tensor& dest, const tensor& src, const tensor& A, const tensor& B ); // ----------------------------------------------------------------------------------- void affine_transform( const rectangle& rect, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, float A, float B, float C ); // ----------------------------------------------------------------------------------- void compute_adam_update ( size_t begin, size_t end, tensor& s, tensor& m, tensor& v, const float t, const float learning_rate, const float weight_decay, const float momentum1, const float momentum2, const tensor& params, const tensor& params_grad ); // ----------------------------------------------------------------------------------- void batch_normalize_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ); void batch_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ); void batch_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ); void batch_normalize_conv_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ); void batch_normalize_conv ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& 
running_variances, const tensor& src, const tensor& gamma, const tensor& beta ); void batch_normalize_conv_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ); // ----------------------------------------------------------------------------------- void layer_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const tensor& src, const tensor& gamma, const tensor& beta ); void layer_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad, resizable_tensor& dmeans, resizable_tensor& dvars ); // ----------------------------------------------------------------------------------- void rms_normalize( const double eps, resizable_tensor& dest, resizable_tensor& scale, const tensor& src, const tensor& gamma ); void rms_normalize_gradient( const tensor& gradient_input, const tensor& scale, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, resizable_tensor& dscale ); // ----------------------------------------------------------------------------------- void threshold ( tensor& data, float thresh ); void dot ( const tensor& a, const tensor& b, tensor& result, size_t idx ); // ----------------------------------------------------------------------------------- void softmax( tensor& dest, const tensor& src, operation_mode mode = operation_mode::CHANNEL_WISE ); void softmax_gradient( tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode = operation_mode::CHANNEL_WISE ); // ------------------------------------------------------------------------------------ void softmax_all ( tensor& dest, const tensor& src ); void softmax_all_gradient ( tensor& grad, const tensor& dest, const 
tensor& gradient_input ); // ------------------------------------------------------------------------------------ void sigmoid ( tensor& dest, const tensor& src ); void sigmoid_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void mish ( tensor& dest, const tensor& src ); void mish_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void relu ( tensor& dest, const tensor& src ); void relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ---------------------------------------------------------------------------------------- void prelu ( tensor& dest, const tensor& src, const tensor& param ); void prelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input, const tensor& param, tensor& params_grad ); // ------------------------------------------------------------------------------------ void leaky_relu ( tensor& dest, const tensor& src, const float alpha ); void leaky_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ); // ------------------------------------------------------------------------------------ void tanh ( tensor& dest, const tensor& src ); void tanh_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void clipped_relu ( tensor& dest, const tensor& src, const float ceiling ); void clipped_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float ceiling ); // ------------------------------------------------------------------------------------ void elu ( tensor& dest, const tensor& src, const float alpha ); void elu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, 
const float alpha ); // ---------------------------------------------------------------------------------------- void gelu ( tensor& dest, const tensor& src ); void gelu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ---------------------------------------------------------------------------------------- void smelu ( tensor& dest, const tensor& src, const float beta ); void smelu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float beta ); // ---------------------------------------------------------------------------------------- void silu ( tensor& dest, const tensor& src ); void silu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void resize_bilinear ( tensor& dest, long long dest_row_stride, long long dest_channel_stride, const tensor& src, long long src_row_stride, long long src_channel_stride ); void resize_bilinear_gradient ( tensor& grad, long long grad_row_stride, long long grad_channel_stride, const tensor& gradient_input, long long gradient_input_row_stride, long long gradient_input_channel_stride ); inline void resize_bilinear ( tensor& dest, const tensor& src ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); } inline void resize_bilinear_gradient ( tensor& grad, const tensor& gradient_input ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); } // ----------------------------------------------------------------------------------- void reorg ( bool add_to, tensor& dest, const int row_stride, const int col_stride, const tensor& src ); void reorg_gradient ( bool add_to, tensor& grad, const int row_stride, const int col_stride, const tensor& gradient_input ); // 
----------------------------------------------------------------------------------- void embeddings( resizable_tensor& dest, const tensor& src, const tensor& embs ); void embeddings_gradient( const tensor& prev, const tensor& gradient_input, tensor& grads, const tensor& freqs, float learning_rate, bool scale ); // ----------------------------------------------------------------------------------- void compute_act_halt_probabilities( resizable_tensor& halt_probs, resizable_tensor& logits, const tensor& input_data, const tensor& halt_params, long batch_size, long seq_len, long feature_dim ); void update_act_state( resizable_tensor& output, const tensor& input_data, const tensor& halt_probs, resizable_tensor& cumulative_halting, resizable_tensor& remainders, resizable_tensor& n_steps, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels, float halt_threshold, long current_step ); void finalize_act_output( resizable_tensor& output, const tensor& input_data, const tensor& remainders, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels ); void apply_act_depth_scaling( tensor& gradients, const tensor& n_steps, long batch_size, long seq_len, long d_model, long num_channels, float max_steps, float scale_factor ); // ----------------------------------------------------------------------------------- class pooling { public: pooling(const pooling&) = delete; pooling& operator=(const pooling&) = delete; pooling ( ); void clear( ); void setup_max_pooling( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x ); void setup_avg_pooling( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x ); bool does_max_pooling( ) const { return do_max_pooling; } void operator() ( resizable_tensor& dest, const tensor& src ); void get_gradient( const tensor& gradient_input, const tensor& dest, const tensor& src, tensor& 
grad ); private: int window_height; int window_width; int stride_y; int stride_x; int padding_y; int padding_x; bool do_max_pooling; }; // ----------------------------------------------------------------------------------- class tensor_conv { public: tensor_conv(const tensor_conv&) = delete; tensor_conv& operator=(const tensor_conv&) = delete; tensor_conv() {} void clear( ) {} void setup( const tensor& data, /* not used but required for interface */ const tensor& filters, /* not used but required for interface */ int stride_y, int stride_x, int padding_y, int padding_x ) { (void)data; /* silence compiler */ DLIB_CASSERT(stride_y > 0 && stride_x > 0); DLIB_CASSERT(0 <= padding_y && padding_y < filters.nr()); DLIB_CASSERT(0 <= padding_x && padding_x < filters.nc()); last_stride_y = stride_y; last_stride_x = stride_x; last_padding_y = padding_y; last_padding_x = padding_x; } void operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters ); void operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters ); void operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ); void operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ); void get_gradient_for_data ( const bool add_to_output, const tensor& gradient_input, const tensor& filters, tensor& data_gradient ); void get_gradient_for_filters ( const bool add_to_output, const tensor& gradient_input, const tensor& data, tensor& filters_gradient ); private: long last_stride_y = 0; long last_stride_x = 0; long last_padding_y = 0; long last_padding_x = 0; }; // ----------------------------------------------------------------------------------- void copy_tensor( bool add_to, tensor& dest, size_t dest_k_offset, const tensor& src, size_t src_k_offset, size_t count_k ); // 
----------------------------------------------------------------------------------- void copy_tensor( bool add_to, tensor& dest, size_t dk, size_t dnr, size_t dnc, const tensor& src, size_t sk, size_t snr, size_t snc, size_t k, size_t nr, size_t nc ); // ----------------------------------------------------------------------------------- void transpose( bool add_to, tensor& dest, const tensor& src ); // ----------------------------------------------------------------------------------- class compute_loss_binary_log_per_pixel { /*! The point of this class is to compute the loss for loss_binary_log_per_pixel_ on the cpu to provide an analogous implementation of the cuda version !*/ public: compute_loss_binary_log_per_pixel( ) { } template < typename const_label_iterator > void operator()( const_label_iterator truth, const tensor& output_tensor, tensor& grad, double& loss ) const { sigmoid(grad, output_tensor); // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. const double scale = 1.0/(output_tensor.num_samples()*output_tensor.nr()*output_tensor.nc()); loss = 0; float* const g = grad.host(); const float* const out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { const float y = truth->operator()(r, c); const size_t idx = tensor_index(output_tensor, i, 0, r, c); if (y > 0.f) { const float temp = log1pexp(-out_data[idx]); loss += y*scale*temp; g[idx] = y*scale*(g[idx]-1); } else if (y < 0.f) { const float temp = -(-out_data[idx]-log1pexp(-out_data[idx])); loss += -y*scale*temp; g[idx] = -y*scale*g[idx]; } else { g[idx] = 0.f; } } } } } }; // ----------------------------------------------------------------------------------- class compute_loss_multiclass_log_per_pixel { /*! 
The point of this class is to compute the loss for loss_multiclass_log_per_pixel_ on the cpu to provide an analogous implementation of the cuda version !*/ public: compute_loss_multiclass_log_per_pixel( ) { } template < typename const_label_iterator > void operator()( const_label_iterator truth, const tensor& output_tensor, tensor& grad, double& loss ) const { softmax(grad, output_tensor); // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc()); loss = 0; float* const g = grad.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { const uint16_t y = truth->operator()(r, c); // The network must produce a number of outputs that is equal to the number // of labels when using this type of loss. DLIB_CASSERT(static_cast(y) < output_tensor.k() || y == label_to_ignore, "y: " << y << ", output_tensor.k(): " << output_tensor.k()); for (long k = 0; k < output_tensor.k(); ++k) { const size_t idx = tensor_index(output_tensor, i, k, r, c); if (k == y) { loss += scale*-safe_log(g[idx]); g[idx] = scale*(g[idx] - 1); } else if (y == label_to_ignore) { g[idx] = 0.f; } else { g[idx] = scale*g[idx]; } } } } } } private: static const uint16_t label_to_ignore = std::numeric_limits::max(); }; // ----------------------------------------------------------------------------------- class compute_loss_multiclass_log_per_pixel_weighted { /*! 
The point of this class is to compute the loss for loss_multiclass_log_per_pixel_weighted_ on the cpu to provide an analogous implementation of the cuda version !*/ public: compute_loss_multiclass_log_per_pixel_weighted( ) { } template < typename const_label_iterator > void operator()( const_label_iterator truth, const tensor& output_tensor, tensor& grad, double& loss ) const { softmax(grad, output_tensor); // The loss we output is the weighted average loss over the mini-batch, and also over each element of the matrix output. const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc()); loss = 0; float* const g = grad.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { const weighted_label& weighted_label = truth->operator()(r, c); const uint16_t y = weighted_label.label; const float weight = weighted_label.weight; // The network must produce a number of outputs that is equal to the number // of labels when using this type of loss. DLIB_CASSERT(static_cast(y) < output_tensor.k() || weight == 0.f, "y: " << y << ", output_tensor.k(): " << output_tensor.k()); for (long k = 0; k < output_tensor.k(); ++k) { const size_t idx = tensor_index(output_tensor, i, k, r, c); if (k == y) { loss += weight*scale*-safe_log(g[idx]); g[idx] = weight*scale*(g[idx] - 1); } else { g[idx] = weight*scale*g[idx]; } } } } } } }; // ----------------------------------------------------------------------------------- class compute_loss_mean_squared_per_channel_and_pixel { /*! 
The point of this class is to compute the loss for loss_mean_squared_per_channel_and_pixel_ on the cpu to provide an analogous implementation of the cuda version !*/ public: compute_loss_mean_squared_per_channel_and_pixel( ) { } template < typename const_label_iterator > void operator()( const_label_iterator truth, const tensor& output_tensor, tensor& grad, double& loss ) const { // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output. const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.k() * output_tensor.nr() * output_tensor.nc()); loss = 0; float* const g = grad.host(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { for (long k = 0; k < output_tensor.k(); ++k) { for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { const float y = (*truth)[k].operator()(r, c); const size_t idx = tensor_index(output_tensor, i, k, r, c); const float temp1 = y - out_data[idx]; const float temp2 = scale*temp1; loss += temp2*temp1; g[idx] = -temp2; } } } } } }; // ----------------------------------------------------------------------------------- } } #ifdef NO_MAKEFILE #include "cpu_dlib.cpp" #endif #endif // DLIB_DNN_CPU_H_ ================================================ FILE: dlib/cuda/cublas_dlibapi.cpp ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DNN_CuBLAS_CPP_ #define DLIB_DNN_CuBLAS_CPP_ #ifdef DLIB_USE_CUDA #include "cublas_dlibapi.h" #include "cuda_utils.h" #include #include static const char* cublas_get_error_string(cublasStatus_t s) { switch(s) { case CUBLAS_STATUS_NOT_INITIALIZED: return "CUDA Runtime API initialization failed."; case CUBLAS_STATUS_ALLOC_FAILED: return "CUDA Resources could not be allocated."; default: return "A call to cuBLAS failed"; } } // Check the return value of a call to the cuBLAS runtime for an error condition. #define CHECK_CUBLAS(call) \ do{ \ const cublasStatus_t error = call; \ if (error != CUBLAS_STATUS_SUCCESS) \ { \ std::ostringstream sout; \ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\ throw dlib::cublas_error(sout.str()); \ } \ }while(false) namespace dlib { namespace cuda { // ----------------------------------------------------------------------------------- class cublas_context { public: // not copyable cublas_context(const cublas_context&) = delete; cublas_context& operator=(const cublas_context&) = delete; cublas_context() { handles.resize(16); } ~cublas_context() { for (auto h : handles) { if (h) cublasDestroy(h); } } cublasHandle_t get_handle ( ) { int new_device_id; CHECK_CUDA(cudaGetDevice(&new_device_id)); // make room for more devices if needed if (new_device_id >= (long)handles.size()) handles.resize(new_device_id+16); // If we don't have a handle already for this device then make one if (!handles[new_device_id]) CHECK_CUBLAS(cublasCreate(&handles[new_device_id])); // Finally, return the handle for the current device return handles[new_device_id]; } private: std::vector handles; }; static cublasHandle_t context() { thread_local cublas_context c; return c.get_handle(); } // ----------------------------------------------------------------------------------- void gemm ( float beta, tensor& dest, float alpha, const 
tensor& lhs, bool trans_lhs, const tensor& rhs, bool trans_rhs, operation_mode mode ) { if (mode == operation_mode::CHANNEL_WISE) { // Recall that BLAS uses column major order so to deal with that we flip the // order of the lhs and rhs arguments. const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N; const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N; const int dest_nr = dest.num_samples(); const int dest_nc = dest.size() / dest_nr; const int lhs_nr = lhs.num_samples(); const int lhs_nc = lhs.size() / lhs_nr; const int rhs_nr = rhs.num_samples(); const int rhs_nc = rhs.size() / rhs_nr; if (trans_lhs && trans_rhs) { DLIB_ASSERT(dest_nr == lhs_nc && dest_nc == rhs_nr && lhs_nr == rhs_nc) } else if (!trans_lhs && trans_rhs) { DLIB_ASSERT(dest_nr == lhs_nr && dest_nc == rhs_nr && lhs_nc == rhs_nc) } else if (trans_lhs && !trans_rhs) { DLIB_ASSERT(dest_nr == lhs_nc && dest_nc == rhs_nc && lhs_nr == rhs_nr) } else { DLIB_ASSERT(dest_nr == lhs_nr && dest_nc == rhs_nc && lhs_nc == rhs_nr) } const int k = trans_rhs ? rhs_nc : rhs_nr; CHECK_CUBLAS(cublasSgemm(context(), transb, transa, dest_nc, dest_nr, k, &alpha, rhs.device(), rhs_nc, lhs.device(), lhs_nc, &beta, dest.device(), dest_nc)); } else if (mode == operation_mode::PLANE_WISE) { const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N; const auto transb = trans_rhs ? 
CUBLAS_OP_T : CUBLAS_OP_N; long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() }); long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() }); auto is_matrix = [](const auto& tensor) { return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) || (tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1)); }; const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest); if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) num_samples = num_channels = 1; size_t lhs_rows = lhs.nr(); size_t lhs_cols = lhs.nc(); if (lhs_is_matrix && (lhs.num_samples() > 1 || lhs.k() > 1)) { lhs_rows = lhs.num_samples(); lhs_cols = lhs.k(); } size_t rhs_rows = rhs.nr(); size_t rhs_cols = rhs.nc(); if (rhs_is_matrix && (rhs.num_samples() > 1 || rhs.k() > 1)) { rhs_rows = rhs.num_samples(); rhs_cols = rhs.k(); } size_t dest_rows = dest.nr(); size_t dest_cols = dest.nc(); if (dest_is_matrix && (dest.num_samples() > 1 || dest.k() > 1)) { dest_rows = dest.num_samples(); dest_cols = dest.k(); } const size_t lhs_plane_size = lhs_rows * lhs_cols; const size_t rhs_plane_size = rhs_rows * rhs_cols; const size_t dest_plane_size = dest_rows * dest_cols; for (long b = 0; b < num_samples; ++b) { for (long c = 0; c < num_channels; ++c) { auto lhs_slice = lhs_is_matrix ? lhs.device() : lhs.device() + (b * num_channels + c) * lhs_plane_size; auto rhs_slice = rhs_is_matrix ? rhs.device() : rhs.device() + (b * num_channels + c) * rhs_plane_size; auto dest_slice = dest_is_matrix ? dest.device() : dest.device() + (b * num_channels + c) * dest_plane_size; const int k = trans_rhs ? 
rhs_cols : rhs_rows; CHECK_CUBLAS(cublasSgemm( context(), transb, transa, dest_cols, dest_rows, k, &alpha, rhs_slice, rhs_cols, lhs_slice, lhs_cols, &beta, dest_slice, dest_cols )); } } } } // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuBLAS_CPP_ ================================================ FILE: dlib/cuda/cublas_dlibapi.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CuBLAS_H_ #define DLIB_DNN_CuBLAS_H_ #ifdef DLIB_USE_CUDA #include "tensor.h" #include "cuda_errors.h" namespace dlib { namespace cuda { // ----------------------------------------------------------------------------------- void gemm ( float beta, tensor& dest, float alpha, const tensor& lhs, bool trans_lhs, const tensor& rhs, bool trans_rhs, operation_mode mode = operation_mode::CHANNEL_WISE ); /*! requires - The dimensions of lhs and rhs must be compatible for matrix multiplication. The specific requirements depend on the mode: For CHANNEL_WISE mode (default): - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs) - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs) - Let D == mat(dest) - D.nr() == L.nr() && D.nc() == R.nc() (i.e. 
dest must be preallocated and have the correct output dimensions) - L.nc() == R.nr() For PLANE_WISE mode: - lhs.num_samples() == rhs.num_samples() && lhs.k() == rhs.k() - If !trans_lhs && !trans_rhs: lhs.nc() == rhs.nr() dest.nr() == lhs.nr() && dest.nc() == rhs.nc() - If trans_lhs && !trans_rhs: lhs.nr() == rhs.nr() dest.nr() == lhs.nc() && dest.nc() == rhs.nc() - If !trans_lhs && trans_rhs: lhs.nc() == rhs.nc() dest.nr() == lhs.nr() && dest.nc() == rhs.nr() - If trans_lhs && trans_rhs: lhs.nr() == rhs.nc() dest.nr() == lhs.nc() && dest.nc() == rhs.nr() ensures - Performs matrix multiplication based on the specified mode: For CHANNEL_WISE mode: - performs: dest = alpha*L*R + beta*mat(dest) where L, R, and D are as defined above. For PLANE_WISE mode: - Performs matrix multiplication for each corresponding 2D plane (nr x nc) in lhs and rhs across all samples and channels. - The operation is equivalent to performing the following for each sample and channel: dest[s][k] = alpha * (lhs[s][k] * rhs[s][k]) + beta * dest[s][k] where [s][k] represents the 2D plane for sample s and channel k. !*/ // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuBLAS_H_ ================================================ FILE: dlib/cuda/cuda_data_ptr.cpp ================================================ // Copyright (C) 2017 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_ #define DLIB_DNN_CuDA_DATA_PTR_CPP_ #ifdef DLIB_USE_CUDA #include "cuda_data_ptr.h" #include "cuda_utils.h" namespace dlib { namespace cuda { // ---------------------------------------------------------------------------------------- weak_cuda_data_void_ptr:: weak_cuda_data_void_ptr( const cuda_data_void_ptr& ptr ) : num(ptr.num), pdata(ptr.pdata) { } // ---------------------------------------------------------------------------------------- cuda_data_void_ptr weak_cuda_data_void_ptr:: lock() const { auto ptr = pdata.lock(); if (ptr) { cuda_data_void_ptr temp; temp.pdata = ptr; temp.num = num; return temp; } else { return cuda_data_void_ptr(); } } // ----------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------- cuda_data_void_ptr:: cuda_data_void_ptr( size_t n ) : num(n) { if (n == 0) return; void* data = nullptr; CHECK_CUDA(cudaMalloc(&data, n)); pdata.reset(data, [](void* ptr){ auto err = cudaFree(ptr); if(err!=cudaSuccess) std::cerr << "cudaFree() failed. 
Reason: " << cudaGetErrorString(err) << std::endl; }); } // ------------------------------------------------------------------------------------ void memcpy( void* dest, const cuda_data_void_ptr& src, const size_t num ) { DLIB_ASSERT(num <= src.size()); if (src.size() != 0) { CHECK_CUDA(cudaMemcpy(dest, src.data(), num, cudaMemcpyDefault)); } } // ------------------------------------------------------------------------------------ void memcpy( void* dest, const cuda_data_void_ptr& src ) { memcpy(dest, src, src.size()); } // ------------------------------------------------------------------------------------ void memcpy( cuda_data_void_ptr dest, const void* src, const size_t num ) { DLIB_ASSERT(num <= dest.size()); if (dest.size() != 0) { CHECK_CUDA(cudaMemcpy(dest.data(), src, num, cudaMemcpyDefault)); } } // ------------------------------------------------------------------------------------ void memcpy( cuda_data_void_ptr dest, const void* src ) { memcpy(dest,src,dest.size()); } // ------------------------------------------------------------------------------------ class cudnn_device_buffer { public: // not copyable cudnn_device_buffer(const cudnn_device_buffer&) = delete; cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete; cudnn_device_buffer() { buffers.resize(16); } ~cudnn_device_buffer() { } cuda_data_void_ptr get ( size_t size ) { int new_device_id; CHECK_CUDA(cudaGetDevice(&new_device_id)); // make room for more devices if needed if (new_device_id >= (long)buffers.size()) buffers.resize(new_device_id+16); // If we don't have a buffer already for this device then make one, or if it's too // small, make a bigger one. 
cuda_data_void_ptr buff = buffers[new_device_id].lock(); if (!buff || buff.size() < size) { buff = cuda_data_void_ptr(size); buffers[new_device_id] = buff; } // Finally, return the buffer for the current device return buff; } private: std::vector buffers; }; // ---------------------------------------------------------------------------------------- cuda_data_void_ptr device_global_buffer(size_t size) { thread_local cudnn_device_buffer buffer; return buffer.get(size); } // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuDA_DATA_PTR_CPP_ ================================================ FILE: dlib/cuda/cuda_data_ptr.h ================================================ // Copyright (C) 2017 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CuDA_DATA_PTR_H_ #define DLIB_DNN_CuDA_DATA_PTR_H_ #include "../assert.h" #ifdef DLIB_USE_CUDA #include #include #include namespace dlib { namespace cuda { // ------------------------------------------------------------------------------------ class cuda_data_void_ptr; class weak_cuda_data_void_ptr { /*! WHAT THIS OBJECT REPRESENTS This is just like a std::weak_ptr version of cuda_data_void_ptr. It allows you to hold a non-owning reference to a cuda_data_void_ptr. !*/ public: weak_cuda_data_void_ptr() = default; weak_cuda_data_void_ptr(const cuda_data_void_ptr& ptr); void reset() { pdata.reset(); num = 0; } cuda_data_void_ptr lock() const; /*! ensures - if (the memory block referenced by this object hasn't been deleted) then - returns a cuda_data_void_ptr referencing that memory block - else - returns a default initialized cuda_data_void_ptr (i.e. an empty one). !*/ private: size_t num = 0; std::weak_ptr pdata; }; // ---------------------------------------------------------------------------------------- class cuda_data_void_ptr { /*! 
WHAT THIS OBJECT REPRESENTS This is a block of memory on a CUDA device. !*/ public: cuda_data_void_ptr() = default; cuda_data_void_ptr(size_t n); /*! ensures - This object will allocate a device memory buffer of n bytes. - #size() == n !*/ void* data() { return pdata.get(); } const void* data() const { return pdata.get(); } operator void*() { return pdata.get(); } operator const void*() const { return pdata.get(); } void reset() { pdata.reset(); } size_t size() const { return num; } /*! ensures - returns the length of this buffer, in bytes. !*/ cuda_data_void_ptr operator+ (size_t offset) const /*! requires - offset < size() ensures - returns a pointer that is offset by the given amount. !*/ { DLIB_CASSERT(offset < num); cuda_data_void_ptr temp; temp.num = num-offset; temp.pdata = std::shared_ptr(pdata, ((char*)pdata.get())+offset); return temp; } void shrink(size_t new_size) /*! requires - new_size <= num ensures - #size() == new_size - Doesn't actually deallocate anything, just changes the size() metadata to a smaller number and only for this instance of the pointer. !*/ { DLIB_CASSERT(new_size <= num); num = new_size; } private: friend class weak_cuda_data_void_ptr; size_t num = 0; std::shared_ptr pdata; }; inline cuda_data_void_ptr operator+(size_t offset, const cuda_data_void_ptr& rhs) { return rhs+offset; } // ------------------------------------------------------------------------------------ void memcpy( void* dest, const cuda_data_void_ptr& src ); /*! requires - dest == a pointer to at least src.size() bytes on the host machine. ensures - copies the GPU data from src into dest. - This routine is equivalent to performing: memcpy(dest,src,src.size()) !*/ void memcpy( void* dest, const cuda_data_void_ptr& src, const size_t num ); /*! requires - dest == a pointer to at least num bytes on the host machine. - num <= src.size() ensures - copies the GPU data from src into dest. Copies only the first num bytes of src to dest. 
!*/ // ------------------------------------------------------------------------------------ void memcpy( cuda_data_void_ptr dest, const void* src ); /*! requires - dest == a pointer to at least src.size() bytes on the host machine. ensures - copies the host data from src to the GPU memory buffer dest. - This routine is equivalent to performing: memcpy(dest,src,dest.size()) !*/ void memcpy( cuda_data_void_ptr dest, const void* src, const size_t num ); /*! requires - dest == a pointer to at least num bytes on the host machine. - num <= dest.size() ensures - copies the host data from src to the GPU memory buffer dest. Copies only the first num bytes of src to dest. !*/ // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ template class cuda_data_ptr { /*! WHAT THIS OBJECT REPRESENTS This is a block of memory on a CUDA device. It is just a type safe version of cuda_data_void_ptr. !*/ public: static_assert(std::is_standard_layout::value, "You can only create basic standard layout types on the GPU"); cuda_data_ptr() = default; cuda_data_ptr(size_t n) : num(n) /*! ensures - This object will allocate a device memory buffer of n T objects. - #size() == n !*/ { if (n == 0) return; pdata = cuda_data_void_ptr(n*sizeof(T)); } cuda_data_ptr( const cuda_data_ptr::type> &other ) : num(other.num), pdata(other.pdata) {} /*! ensures - *this is a copy of other. This version of the copy constructor allows assigning non-const pointers to const ones. For instance, converting from cuda_data_ptr to cuda_data_ptr. !*/ T* data() { return (T*)pdata.data(); } const T* data() const { return (T*)pdata.data(); } operator T*() { return (T*)pdata.data(); } operator const T*() const { return (T*)pdata.data(); } void reset() { pdata.reset(); } size_t size() const { return num; } /*! 
ensures - returns the number of T instances pointed to by *this. !*/ operator cuda_data_void_ptr() const /*! ensures - returns *this as a cuda_data_void_ptr. Importantly, the returned size() will reflect the number of bytes referenced by *this. To be clear, let P be the returned pointer. Then: - P.get() == get() - P.size() == size() * sizeof(T) !*/ { cuda_data_void_ptr temp = pdata; temp.shrink(size() * sizeof(T)); return temp; } private: template friend cuda_data_ptr static_pointer_cast(const cuda_data_void_ptr &ptr); template friend cuda_data_ptr static_pointer_cast(const cuda_data_void_ptr &ptr, size_t num); template friend class cuda_data_ptr; size_t num = 0; cuda_data_void_ptr pdata; }; template cuda_data_ptr static_pointer_cast(const cuda_data_void_ptr &ptr) { DLIB_CASSERT(ptr.size() % sizeof(T) == 0, "Size of memory buffer in ptr doesn't match sizeof(T). " << "\nptr.size(): "<< ptr.size() << "\nsizeof(T): "<< sizeof(T)); cuda_data_ptr result; result.pdata = ptr; result.num = ptr.size() / sizeof(T); return result; } template cuda_data_ptr static_pointer_cast(const cuda_data_void_ptr &ptr, size_t num) { DLIB_CASSERT(num*sizeof(T) <= ptr.size(), "Size of memory buffer in ptr isn't big enough to represent this many T objects. 
" << "\nnum: "<< num << "\nnum*sizeof(T): "<< num*sizeof(T) << "\nsizeof(T): "<< sizeof(T) << "\nptr.size(): "<< ptr.size()); cuda_data_ptr result; result.pdata = ptr; result.num = num; return result; } template void memcpy(std::vector& dest, const cuda_data_ptr& src) { dest.resize(src.size()); if (src.size() != 0) memcpy(dest.data(), static_cast(src)); } template void memcpy(cuda_data_ptr& dest, const std::vector& src) { if (src.size() != dest.size()) dest = cuda_data_ptr(src.size()); if (dest.size() != 0) memcpy(static_cast(dest), src.data()); } template void memcpy(cuda_data_ptr& dest, const T* src) { memcpy(static_cast(dest), src); } template void memcpy(cuda_data_ptr& dest, const T* src, size_t num) { DLIB_CASSERT(num <= dest.size()); memcpy(static_cast(dest), src, num*sizeof(T)); } template void memcpy(T* dest, const cuda_data_ptr& src) { memcpy(dest, static_cast(src)); } template void memcpy(T* dest, const cuda_data_ptr& src, size_t num) { DLIB_CASSERT(num <= src.size()); memcpy(dest, static_cast(src), num*sizeof(T)); } // ------------------------------------------------------------------------------------ cuda_data_void_ptr device_global_buffer(size_t size); /*! ensures - Returns a pointer to a globally shared CUDA memory buffer on the currently selected CUDA device. The buffer is also thread local. So each host thread will get its own buffer. You can use this global buffer as scratch space for CUDA computations that all take place on the default stream. Using it in this way ensures that there aren't any race conditions involving the use of the buffer. - The returned pointer will point to at least size bytes. It may point to more. - The global buffer is deallocated once all references to it are destructed. However, if device_global_buffer() is called before then with a size <= the last size requested, then the previously returned global buffer pointer is returned. This avoids triggering expensive CUDA reallocations. 
So if you want to avoid these reallocations then hold a copy of the pointer returned by this function. However, as a general rule, client code should not hold the returned cuda_data_void_ptr for long durations, but instead should call device_global_buffer() whenever the buffer is needed, and overwrite the previously returned pointer with the new pointer. Doing so ensures multiple buffers are not kept around in the event that multiple sized buffers are requested. To explain this, consider this code, assumed to execute at program startup: auto ptr1 = device_global_buffer(1); auto ptr2 = device_global_buffer(2); auto ptr3 = device_global_buffer(3); since the sizes increased at each call 3 separate buffers were allocated. First one of size 1, then of size 2, then of size 3. If we then executed: ptr1 = device_global_buffer(1); ptr2 = device_global_buffer(2); ptr3 = device_global_buffer(3); all three of these pointers would now point to the same buffer, since the smaller requests can be satisfied by returning the size 3 buffer in each case. !*/ // ---------------------------------------------------------------------------------------- } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuDA_DATA_PTR_H_ ================================================ FILE: dlib/cuda/cuda_dlib.cu ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#include "cuda_utils.h"
#include "cuda_dlib.h"
#include "cudnn_dlibapi.h"
#include <math_constants.h>


namespace dlib
{
    namespace cuda
    {

    // -----------------------------------------------------------------------------------

        // Makes dev the current CUDA device for the calling host thread.
        void set_device (
            int dev
        )
        {
            CHECK_CUDA(cudaSetDevice(dev));
        }

        // Returns the id of the current CUDA device.
        int get_device (
        )
        {
            int dev = 0;
            CHECK_CUDA(cudaGetDevice(&dev));
            return dev;
        }

        // Returns the name string reported by cudaGetDeviceProperties() for device.
        std::string get_device_name (
            int device
        )
        {
            cudaDeviceProp props;
            CHECK_CUDA(cudaGetDeviceProperties(&props, device));
            return props.name;
        }

        // Sets the cudaDeviceScheduleBlockingSync flag on the current device.
        void set_current_device_blocking_sync(
        )
        {
            CHECK_CUDA(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
        }

        // Returns the number of CUDA devices visible to this process.
        int get_num_devices (
        )
        {
            int num_devices;
            CHECK_CUDA(cudaGetDeviceCount(&num_devices));
            return num_devices;
        }

        // Returns true if device_id is able to access memory on peer_device_id.
        bool can_access_peer (int device_id, int peer_device_id)
        {
            int can_access;
            CHECK_CUDA(cudaDeviceCanAccessPeer(&can_access, device_id, peer_device_id));
            return can_access != 0;
        }

        bool can_access_peer (const tensor& device, const tensor& peer_device)
        {
            return can_access_peer(device.device_id(), peer_device.device_id());
        }

        // Blocks until all work queued on device dev has finished.
        void device_synchronize (int dev)
        {
            raii_set_device set_dev(dev);
            CHECK_CUDA(cudaDeviceSynchronize());
        }

        void device_synchronize (const tensor& dev)
        {
            device_synchronize(dev.device_id());
        }

        // Enables peer access from device_id to peer_device_id.  Access that is already
        // enabled is not treated as an error; in that case the destructor will not
        // disable it (call_disable stays false).
        enable_peer_access::
        enable_peer_access(
            int device_id,
            int peer_device_id
        ) : call_disable(false), device_id(device_id), peer_device_id(peer_device_id)
        {
            raii_set_device set_dev(device_id);

            auto err = cudaDeviceEnablePeerAccess(peer_device_id, 0);
            if (err == cudaSuccess)
            {
                call_disable = true;
            }
            else if (err == cudaErrorPeerAccessAlreadyEnabled)
            {
                // call cudaGetLastError() to dispose of this error since we don't
                // care.
                auto err2 = cudaGetLastError();
                if (err2 != cudaErrorPeerAccessAlreadyEnabled)
                    CHECK_CUDA(err2);
            }
            else
            {
                CHECK_CUDA(err);
            }
        }

        // Disables peer access only if this object was the one that enabled it.
        enable_peer_access::
        ~enable_peer_access() noexcept(false)
        {
            if (call_disable)
            {
                raii_set_device set_dev(device_id);
                CHECK_CUDA(cudaDeviceDisablePeerAccess(peer_device_id));
            }
        }

    // -----------------------------------------------------------------------------------
    // -----------------------------------------------------------------------------------
    // -----------------------------------------------------------------------------------

        // For each of the nr rows (each nc floats long) of data, writes
        // 1/sqrt(eps + sum of squares of that row) into invnorms[row].
        // NOTE(review): __syncthreads() only synchronizes threads within one block, so
        // the three phases are only ordered if all threads touching a given row share a
        // block.  This presumably relies on the launch configuration chosen by
        // launch_kernel/max_jobs -- confirm.
        __global__ void _cuda_inverse_norms(float* invnorms, const float* data, size_t nr, size_t nc, const float eps)
        {
            // initialize invnorms before we begin.
            for (auto i : grid_stride_range_y(0, nr))
                for (auto j : grid_stride_range(0, 1))
                    invnorms[i] = eps;
            __syncthreads();

            for (auto i : grid_stride_range_y(0, nr))
            {
                auto p = data + i*nc;
                float temp = 0;
                for (auto j : grid_stride_range(0, nc))
                    temp += p[j]*p[j];

                // and store the sum into invnorms[i]
                warp_reduce_atomic_add(invnorms[i], temp);
            }
            __syncthreads();

            for (auto i : grid_stride_range_y(0, nr))
                for (auto j : grid_stride_range(0, 1))
                    invnorms[i] = 1.0/std::sqrt(invnorms[i]);
        }

        void inverse_norms (
            resizable_tensor& invnorms,
            const tensor& data,
            const double eps
        )
        {
            invnorms.set_size(data.num_samples());
            // Nothing to compute for an empty batch; returning here also avoids the
            // integer division by data.num_samples() below.
            if (data.num_samples() == 0)
                return;
            launch_kernel(_cuda_inverse_norms, max_jobs(data.size()/data.num_samples(), data.num_samples()),
                invnorms.device(), data.device(), data.num_samples(), data.size()/data.num_samples(), eps);
        }

    // ----------------------------------------------------------------------------------------

        // Writes the dot product of row i of lhs with row i of rhs into out[i].
        __global__ void _cuda_dot_prods(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc)
        {
            // initialize out before we begin.
            for (auto i : grid_stride_range_y(0, nr))
                for (auto j : grid_stride_range(0, 1))
                    out[i] = 0;
            __syncthreads();

            for (auto i : grid_stride_range_y(0, nr))
            {
                auto l = lhs + i*nc;
                auto r = rhs + i*nc;
                float temp = 0;
                for (auto j : grid_stride_range(0, nc))
                    temp += l[j]*r[j];

                // and store the sum into out[i]
                warp_reduce_atomic_add(out[i], temp);
            }
        }

        // Same as _cuda_dot_prods() but adds into the existing contents of out.
        __global__ void _cuda_dot_prods_add_to(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc)
        {
            for (auto i : grid_stride_range_y(0, nr))
            {
                auto l = lhs + i*nc;
                auto r = rhs + i*nc;
                float temp = 0;
                for (auto j : grid_stride_range(0, nc))
                    temp += l[j]*r[j];

                // and store the sum into out[i]
                warp_reduce_atomic_add(out[i], temp);
            }
        }

        void dot_prods (
            resizable_tensor& out,
            const tensor& lhs,
            const tensor& rhs
        )
        {
            DLIB_CASSERT(have_same_dimensions(lhs,rhs));

            out.set_size(lhs.num_samples());
            if (out.size() == 0)
                return;

            const auto nr = lhs.num_samples();
            const auto nc = lhs.size()/lhs.num_samples();

            launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc);
        }

        void dot_prods (
            bool add_to,
            tensor& out,
            const tensor& lhs,
            const tensor& rhs
        )
        {
            DLIB_CASSERT(have_same_dimensions(lhs,rhs));
            DLIB_CASSERT(out.k() == 1 && out.nr() == 1 && out.nc() == 1);
            DLIB_CASSERT(out.size() == lhs.num_samples());

            // Same empty-batch guard as the resizable_tensor overload: avoids dividing
            // by lhs.num_samples() == 0 below.
            if (out.size() == 0)
                return;

            const auto nr = lhs.num_samples();
            const auto nc = lhs.size()/lhs.num_samples();

            if (add_to)
                launch_kernel(_cuda_dot_prods_add_to, max_jobs(nc,nr), out.device(), lhs.device(), rhs.device(), nr, nc);
            else
                launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc);
        }

    // ----------------------------------------------------------------------------------------

        // out[j] = m[j] * v[column of j], treating m as an nr x nc row-major matrix.
        __global__ void _cuda_scale_columns(float* out, const float* m, const float* v, size_t nr, size_t nc)
        {
            for (auto j : grid_stride_range(0, nr*nc))
            {
                out[j] = m[j]*v[j%nc];
            }
        }

        void scale_columns (
            tensor& out,
            const tensor& m,
            const tensor& v
        )
        {
            // An empty m has nothing to scale; this also avoids dividing by a zero
            // m.num_samples() below.
            if (m.size() == 0)
                return;
            launch_kernel(_cuda_scale_columns, max_jobs(m.size()), out.device(), m.device(), v.device(), m.num_samples(), m.size()/m.num_samples());
        }

    // ----------------------------------------------------------------------------------------

        // out[j] = m[j] * v[row of j], treating m as an nr x nc row-major matrix.
        __global__ void _cuda_scale_rows(float* out, const float* m, const float* v, size_t nr, size_t nc)
        {
            for (auto j : grid_stride_range(0, nr*nc))
            {
                out[j] = m[j]*v[j/nc];
            }
        }

        void scale_rows (
            tensor& out,
            const tensor& m,
            const tensor& v
        )
        {
            // An empty m has nothing to scale; this also avoids dividing by a zero
            // m.num_samples() below.
            if (m.size() == 0)
                return;
            launch_kernel(_cuda_scale_rows, max_jobs(m.size()), out.device(), m.device(), v.device(), m.num_samples(), m.size()/m.num_samples());
        }

    // ----------------------------------------------------------------------------------------

        // out[j] = (m1[j] - m2[j]*v1[row]) * v2[row]
        __global__ void _cuda_scale_rows2(float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc)
        {
            for (auto j : grid_stride_range(0, nr*nc))
            {
                out[j] = (m1[j] - m2[j]*v1[j/nc]) * v2[j/nc];
            }
        }

        // out[j] = beta*out[j] + (m1[j] - m2[j]*v1[row]) * v2[row]
        __global__ void _cuda_scale_rows2_beta(const float beta, float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc)
        {
            for (auto j : grid_stride_range(0, nr*nc))
            {
                out[j] = beta*out[j] + (m1[j] - m2[j]*v1[j/nc]) * v2[j/nc];
            }
        }

        void scale_rows2 (
            float beta,
            tensor& out,
            const tensor& m1,
            const tensor& m2,
            const tensor& v1,
            const tensor& v2
        )
        {
            // An empty m1 means the kernels would process no elements; this also avoids
            // dividing by a zero m1.num_samples() below.
            if (m1.size() == 0)
                return;

            if (beta == 0)
            {
                launch_kernel(_cuda_scale_rows2, max_jobs(m1.size()), out.device(),
                    m1.device(), m2.device(), v1.device(), v2.device(), m1.num_samples(),
                    m1.size()/m1.num_samples());
            }
            else
            {
                launch_kernel(_cuda_scale_rows2_beta, max_jobs(m1.size()), beta,
                    out.device(), m1.device(), m2.device(), v1.device(), v2.device(),
                    m1.num_samples(), m1.size()/m1.num_samples());
            }
        }

    // ----------------------------------------------------------------------------------------

        // Element-wise dest[i] = exp(src[i]).
        __global__ void _cuda_exp(float* dest, const float* src, size_t n)
        {
            for (auto i : grid_stride_range(0, n))
                dest[i] = ::exp(src[i]);
        }

        void exp (
            tensor& dest,
            const tensor& src
        )
        {
            DLIB_ASSERT(dest.size() == src.size());
            launch_kernel(_cuda_exp, max_jobs(src.size()), dest.device(), src.device(), src.size());
        }

    // ----------------------------------------------------------------------------------------

        // Element-wise dest[i] = log(src[i]).
        __global__ void _cuda_log(float* dest, const float* src, size_t n)
        {
            for (auto i : grid_stride_range(0, n))
                dest[i] = ::log(src[i]);
        }

        void log (
            tensor& dest,
            const tensor& src
        )
        {
            DLIB_ASSERT(dest.size() == src.size());
            launch_kernel(_cuda_log, max_jobs(src.size()), dest.device(), src.device(), src.size());
        }

    // ----------------------------------------------------------------------------------------

        // Element-wise dest[i] = log10(src[i]).
        __global__ void _cuda_log10(float* dest, const float* src, size_t n)
        {
            for (auto i : grid_stride_range(0, n))
                dest[i] = ::log10(src[i]);
        }

        void log10 (
            tensor& dest,
            const tensor& src
        )
        {
            DLIB_ASSERT(dest.size() == src.size());
            launch_kernel(_cuda_log10, max_jobs(src.size()), dest.device(), src.device(), src.size());
        }

    // -----------------------------------------------------------------------------------

        // Same-size case: d[i] = s1[i]*s2[i].
        __global__ void _cuda_multiply1(float* d, const float* s1, const float* s2, size_t n)
        {
            for (auto i : grid_stride_range(0, n))
            {
                d[i] = s1[i]*s2[i];
            }
        }

        // dest holds a single sample: d[i] = sum over j = i, i+n, ... < max_size of
        // s1[j%s1_n]*s2[j%s2_n], i.e. products summed across the sample dimension.
        __global__ void _cuda_multiply2(float* d, const float* s1, const float* s2,
                                       size_t n, size_t s1_n, size_t s2_n, size_t max_size)
        {
            for (auto i : grid_stride_range(0, n))
            {
                d[i] = 0;
                for (size_t j = i; j < max_size; j += n)
                    d[i] += s1[j%s1_n]*s2[j%s2_n];
            }
        }

        // Broadcasting case: each source is indexed modulo its own size.
        __global__ void _cuda_multiply3(float* d, const float* s1, const float* s2,
                                       size_t n, size_t s1_n, size_t s2_n)
        {
            for (auto i : grid_stride_range(0, n))
            {
                d[i] = s1[i%s1_n]*s2[i%s2_n];
            }
        }

        __global__ void _cuda_multiply1_add_to(float* d, const float* s1, const float* s2, size_t n)
        {
            for (auto i : grid_stride_range(0, n))
            {
                d[i] += s1[i]*s2[i];
            }
        }

        __global__ void _cuda_multiply2_add_to(float* d, const float* s1, const float* s2,
                                              size_t n, size_t s1_n, size_t s2_n, size_t max_size)
        {
            for (auto i : grid_stride_range(0, n))
            {
                for (size_t j = i; j < max_size; j += n)
                    d[i] += s1[j%s1_n]*s2[j%s2_n];
            }
        }

        __global__ void _cuda_multiply3_add_to(float* d, const float* s1, const float* s2,
                                              size_t n, size_t s1_n, size_t s2_n)
        {
            for (auto i : grid_stride_range(0, n))
            {
                d[i] += s1[i%s1_n]*s2[i%s2_n];
            }
        }

        // Element-wise product of src1 and src2 with broadcasting over num_samples()
        // (each tensor's num_samples() must be 1 or the common maximum).  When add_to is
        // true the product is added to dest instead of overwriting it.
        void multiply (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        )
        {
            DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() &&
                dest.nr() == src1.nr() && src1.nr() == src2.nr() &&
                dest.nc() == src1.nc() && src1.nc() == src2.nc() );
            const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples());
            DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) &&
                (src1.num_samples()==1 || src1.num_samples()==MD) &&
                (src2.num_samples()==1 || src2.num_samples()==MD) );

            if (dest.size() == 0)
                return;

            const size_t max_size = std::max(std::max(dest.size(),src1.size()),src2.size());

            // All work is done on the device, so no host() pointers are fetched here.
            if (dest.size() == src1.size() && src1.size() == src2.size())
            {
                if (add_to)
                    launch_kernel(_cuda_multiply1_add_to,max_jobs(dest.size()),
                        dest.device(), src1.device(), src2.device(), src1.size());
                else
                    launch_kernel(_cuda_multiply1,max_jobs(dest.size()),
                        dest.device(), src1.device(), src2.device(), src1.size());
            }
            else if (dest.num_samples() == 1)
            {
                if (add_to)
                    launch_kernel(_cuda_multiply2_add_to,max_jobs(dest.size()),
                        dest.device(), src1.device(), src2.device(),
                        dest.size(), src1.size(), src2.size(), max_size);
                else
                    launch_kernel(_cuda_multiply2,max_jobs(dest.size()),
                        dest.device(), src1.device(), src2.device(),
                        dest.size(), src1.size(), src2.size(), max_size);
            }
            else
            {
                if (add_to)
                    launch_kernel(_cuda_multiply3_add_to,max_jobs(dest.size()),
                        dest.device(), src1.device(), src2.device(),
                        dest.size(), src1.size(), src2.size());
                else
                    launch_kernel(_cuda_multiply3,max_jobs(dest.size()),
                        dest.device(), src1.device(), src2.device(),
                        dest.size(), src1.size(), src2.size());
            }
        }

// 
------------------------------------------------------------------------------------ __global__ void _cuda_multiply_conv(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) { for (auto i : grid_stride_range(0, n)) { auto k = (i/bs)%ks; d[i] = s1[i]*s2[k]; } } __global__ void _cuda_multiply_conv2(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) { // zero initialize d before we begin. for (auto i : grid_stride_range_y(0, ks)) for (auto j : grid_stride_range(0, 1)) d[i] = 0; __syncthreads(); // loop over all the image planes for (auto i : grid_stride_range_y(0, n)) { // sum all the elements in the i-th image plane float temp = 0; for (auto j : grid_stride_range(i*bs, (i+1)*bs)) temp += s1[j]*s2[j]; auto k = i%ks; // and store the sum into d[k] warp_reduce_atomic_add(d[k], temp); } } __global__ void _cuda_multiply_conv_add_to(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) { for (auto i : grid_stride_range(0, n)) { auto k = (i/bs)%ks; d[i] += s1[i]*s2[k]; } } __global__ void _cuda_multiply_conv2_add_to(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks) { // loop over all the image planes for (auto i : grid_stride_range_y(0, n)) { // sum all the elements in the i-th image plane float temp = 0; for (auto j : grid_stride_range(i*bs, (i+1)*bs)) temp += s1[j]*s2[j]; auto k = i%ks; // and store the sum into d[k] warp_reduce_atomic_add(d[k], temp); } } void multiply_conv ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ) { if (have_same_dimensions(dest,src1)) { DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k()); if (dest.size() == 0) return; if (add_to) launch_kernel(_cuda_multiply_conv_add_to,max_jobs(dest.size()), dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k()); else launch_kernel(_cuda_multiply_conv,max_jobs(dest.size()), dest.device(), src1.device(), src1.size(), 
src2.device(), src1.nr()*src1.nc(), src1.k()); } else { DLIB_CASSERT(have_same_dimensions(src1,src2)); DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k()); if (dest.size() == 0) return; const auto bs = src1.nr()*src1.nc(); const auto n = src1.num_samples()*src1.k(); if (add_to) launch_kernel(_cuda_multiply_conv2_add_to, max_jobs(bs,n), dest.device(), src1.device(), n, src2.device(), bs, src1.k()); else launch_kernel(_cuda_multiply_conv2, max_jobs(bs,n), dest.device(), src1.device(), n, src2.device(), bs, src1.k()); } } // ------------------------------------------------------------------------------------ __global__ void _cuda_scale_channels_add_to(float* d, const float* src, size_t n, const float* scales, size_t bs) { for (auto i : grid_stride_range(0, n)) { auto k = i/bs; d[i] += src[i]*scales[k]; } } __global__ void _cuda_scale_channels(float* d, const float* src, size_t n, const float* scales, size_t bs) { for (auto i : grid_stride_range(0, n)) { auto k = i/bs; d[i] = src[i]*scales[k]; } } void scale_channels ( bool add_to, tensor& dest, const tensor& src, const tensor& scales ) { DLIB_CASSERT(have_same_dimensions(dest,src) && scales.num_samples() == src.num_samples() && scales.k() == src.k() && scales.nr() == 1 && scales.nc() == 1 ); if (dest.size() == 0) return; if (add_to) launch_kernel(_cuda_scale_channels_add_to,max_jobs(dest.size()), dest.device(), src.device(), src.size(), scales.device(), src.nr()*src.nc()); else launch_kernel(_cuda_scale_channels,max_jobs(dest.size()), dest.device_write_only(), src.device(), src.size(), scales.device(), src.nr()*src.nc()); } // ------------------------------------------------------------------------------------ __global__ void _cuda_mult1(float* d, const float* s1, const float* s2, size_t n) { for (auto i : grid_stride_range(0, n)) { d[i] = s1[i]*s2[i]; } } __global__ void _cuda_mult1_add_to(float* d, const float* s1, const float* s2, size_t n) { for (auto i : 
grid_stride_range(0, n)) { d[i] += s1[i]*s2[i]; } } __global__ void _cuda_mult2(float* d, const float* s1, const float* s2, size_t dn, size_t dk, size_t dr, size_t dc, size_t s1n, size_t s1k, size_t s1r, size_t s1c, size_t s2n, size_t s2k, size_t s2r, size_t s2c) { for (auto i : grid_stride_range(0, dn*dk*dr*dc)) { size_t n,k,r,c; unpack_idx(i, dk,dr,dc, n,k,r,c); float v1 = 0; float v2 = 0; if (n < s1n && k < s1k && r < s1r && c < s1c ) { v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)]; } if (n < s2n && k < s2k && r < s2r && c < s2c ) { v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)]; } d[i] = v1*v2; } } __global__ void _cuda_mult2_add_to(float* d, const float* s1, const float* s2, size_t dn, size_t dk, size_t dr, size_t dc, size_t s1n, size_t s1k, size_t s1r, size_t s1c, size_t s2n, size_t s2k, size_t s2r, size_t s2c) { for (auto i : grid_stride_range(0, dn*dk*dr*dc)) { size_t n,k,r,c; unpack_idx(i, dk,dr,dc, n,k,r,c); float v1 = 0; float v2 = 0; if (n < s1n && k < s1k && r < s1r && c < s1c ) { v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)]; } if (n < s2n && k < s2k && r < s2r && c < s2c ) { v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)]; } d[i] += v1*v2; } } void multiply_zero_padded ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ) { if (dest.size() == 0) return; // Do the simple and fast version if everything has the same dimensions if (have_same_dimensions(dest, src1) && have_same_dimensions(dest, src2)) { if (add_to) launch_kernel(_cuda_mult1_add_to,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size()); else launch_kernel(_cuda_mult1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size()); } else { if (add_to) { // Otherwise, do the more complex version with bounds checking. 
launch_kernel(_cuda_mult2_add_to,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.num_samples(), dest.k(), dest.nr(), dest.nc(), src1.num_samples(), src1.k(), src1.nr(), src1.nc(), src2.num_samples(), src2.k(), src2.nr(), src2.nc() ); } else { // Otherwise, do the more complex version with bounds checking. launch_kernel(_cuda_mult2,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.num_samples(), dest.k(), dest.nr(), dest.nc(), src1.num_samples(), src1.k(), src1.nr(), src1.nc(), src2.num_samples(), src2.k(), src2.nr(), src2.nc() ); } } } // ------------------------------------------------------------------------------------ __global__ void _cuda_add1(float* d, const float* s1, const float* s2, size_t n) { for (auto i : grid_stride_range(0, n)) { d[i] = s1[i]+s2[i]; } } __global__ void _cuda_add2(float* d, const float* s1, const float* s2, size_t dn, size_t dk, size_t dr, size_t dc, size_t s1n, size_t s1k, size_t s1r, size_t s1c, size_t s2n, size_t s2k, size_t s2r, size_t s2c) { for (auto i : grid_stride_range(0, dn*dk*dr*dc)) { size_t n,k,r,c; unpack_idx(i, dk,dr,dc, n,k,r,c); float v1 = 0; float v2 = 0; if (n < s1n && k < s1k && r < s1r && c < s1c ) { v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)]; } if (n < s2n && k < s2k && r < s2r && c < s2c ) { v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)]; } d[i] = v1+v2; } } void add ( tensor& dest, const tensor& src1, const tensor& src2 ) { if (dest.size() == 0) return; // Do the simple and fast version if everything has the same dimensions if (have_same_dimensions(dest, src1) && have_same_dimensions(dest, src2)) { launch_kernel(_cuda_add1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size()); } else { // Otherwise, do the more complex version with bounds checking. 
launch_kernel(_cuda_add2,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.num_samples(), dest.k(), dest.nr(), dest.nc(), src1.num_samples(), src1.k(), src1.nr(), src1.nc(), src2.num_samples(), src2.k(), src2.nr(), src2.nc() ); } } // ------------------------------------------------------------------------------------ __global__ void _cuda_affine_transform1(float* d, const float* s, size_t n, float A, float B) { for (auto i : grid_stride_range(0, n)) { d[i] = A*s[i] + B; } } __global__ void _cuda_affine_transform1_0(float* d, const float* s, size_t n, float A) { for (auto i : grid_stride_range(0, n)) { d[i] = A*s[i]; } } void affine_transform( tensor& dest, const tensor& src, const float A, const float B ) { DLIB_CASSERT(dest.size()==src.size()); if (B != 0) launch_kernel(_cuda_affine_transform1,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A, B); else launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A); } void affine_transform( tensor& dest, const tensor& src, const float A ) { DLIB_CASSERT(dest.size()==src.size()); launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_affine_transform_rect( float* d, const float* s1, const float* s2, const float* s3, float A, float B, float C, size_t start_idx, size_t n, size_t rect_nc, size_t total_nc ) { for (auto i : grid_stride_range(0, n)) { size_t r = i/rect_nc; size_t c = i%rect_nc; size_t idx = r*total_nc + c + start_idx; d[idx] = A*s1[idx] + B*s2[idx] + C*s3[idx]; } } void affine_transform( const rectangle& rect, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, float A, float B, float C ) { DLIB_CASSERT(dest.size() == src1.size()); DLIB_CASSERT(dest.size() == src2.size()); DLIB_CASSERT(dest.size() == src3.size()); 
DLIB_CASSERT(dest.num_samples() == src1.num_samples()); DLIB_CASSERT(dest.num_samples() == src2.num_samples()); DLIB_CASSERT(dest.num_samples() == src3.num_samples()); DLIB_CASSERT(rectangle(0,0, dest.size()/dest.num_samples()-1, dest.num_samples()-1).contains(rect)); launch_kernel(_cuda_affine_transform_rect,max_jobs(rect.area()), dest.device(), src1.device(), src2.device(), src3.device(), A, B, C, rect.left() + rect.top()*(dest.size()/dest.num_samples()), rect.area(), rect.width(), dest.size()/dest.num_samples()); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_affine_transform4(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C) { for (auto i : grid_stride_range(0, n)) { d[i] = A*s1[i] + B*s2[i] + C; } } __global__ void _cuda_affine_transform4_0(float* d, const float* s1, const float* s2, size_t n, float A, float B) { for (auto i : grid_stride_range(0, n)) { d[i] = A*s1[i] + B*s2[i]; } } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B, const float C ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); if (C != 0) launch_kernel(_cuda_affine_transform4,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B, C); else launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B); } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale) { for (auto i : 
grid_stride_range(0, n)) { d[i] += scale*s[i]; } } void add_scaled( tensor& dest, const float scale, const tensor& src ) { DLIB_CASSERT(dest.size()==src.size()); launch_kernel(_cuda_add_scaled,max_jobs(dest.size()),dest.device(), src.device(), dest.size(), scale); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_add_cv_to_all_columns(float beta, float* dest, float alpha, const float* src, size_t size, size_t stride) { for (auto i : grid_stride_range(0, size)) { dest[i] = beta*dest[i] + alpha*src[i/stride]; } } __global__ void _cuda_add_cv_to_all_columns_no_beta(float* dest, float alpha, const float* src, size_t size, size_t stride) { for (auto i : grid_stride_range(0, size)) { dest[i] = alpha*src[i/stride]; } } void add_cv_to_all_columns( float beta, tensor& dest, float alpha, const tensor& src ) { DLIB_CASSERT(dest.num_samples() == src.num_samples() && src.num_samples() == src.size()); if (beta == 0) launch_kernel(_cuda_add_cv_to_all_columns_no_beta, max_jobs(dest.size()), dest.device(), alpha, src.device(), dest.size(), dest.size()/dest.num_samples()); else launch_kernel(_cuda_add_cv_to_all_columns, max_jobs(dest.size()), beta, dest.device(), alpha, src.device(), dest.size(), dest.size()/dest.num_samples()); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_affine_transform5( float* d, const float* s1, const float* s2, const float* s3, size_t n, float A, float B, float C, float D ) { for (auto i : grid_stride_range(0, n)) { d[i] = A*s1[i] + B*s2[i] + C*s3[i] + D; } } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C, const float D ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); DLIB_CASSERT(dest.size()==src3.size()); launch_kernel(_cuda_affine_transform5,max_jobs(dest.size()),dest.device(), src1.device(), 
src2.device(), src3.device(), dest.size(), A, B, C, D); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_affine_transform_range( float* d, const float* s1, const float* s2, const float* s3, size_t begin, size_t end, float A, float B, float C ) { for (auto i : grid_stride_range(begin, end)) { d[i] = A*s1[i] + B*s2[i] + C*s3[i]; } } void affine_transform_range( size_t begin, size_t end, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ) { DLIB_CASSERT(dest.size()==src1.size()); DLIB_CASSERT(dest.size()==src2.size()); DLIB_CASSERT(dest.size()==src3.size()); DLIB_CASSERT(begin <= end && end <= dest.size()); launch_kernel(_cuda_affine_transform_range,max_jobs(end-begin), dest.device(), src1.device(), src2.device(), src3.device(), begin, end, A, B, C); } // ----------------------------------------------------------------------------------- __global__ void _cuda_affine_transform2(float* d, const float* s, size_t n, const float* A, const float* B) { for (auto i : grid_stride_range(0, n)) { d[i] = A[i]*s[i] + B[i]; } } __global__ void _cuda_affine_transform3(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs) { for (auto i : grid_stride_range(0, n)) { d[i] = A[i%bs]*s[i] + B[i%bs]; } } void affine_transform( tensor& dest, const tensor& src, const tensor& A, const tensor& B ) { DLIB_CASSERT(have_same_dimensions(dest, src)); DLIB_CASSERT( ((A.num_samples()==1 && B.num_samples()==1) || (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples()))); DLIB_CASSERT( A.nr()==B.nr() && B.nr()==src.nr() && A.nc()==B.nc() && B.nc()==src.nc() && A.k() ==B.k() && B.k()==src.k(), "\nA.nr(): " << A.nr() << "\nB.nr(): " << B.nr() << "\nsrc.nr(): " << src.nr() <<"\nA.nc(): " << A.nc() << "\nB.nc(): " << B.nc() << "\nsrc.nc(): " << src.nc() <<"\nA.k(): " << A.k() << "\nB.k(): " << B.k() << "\nsrc.k(): " << src.k() ); 
// Dispatch on whether A holds a single sample (_cuda_affine_transform3) or not
// (_cuda_affine_transform2).
if (A.num_samples() == 1)
{
    launch_kernel(_cuda_affine_transform3,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device(), A.size());
}
else
{
    launch_kernel(_cuda_affine_transform2,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device());
}
}

// ----------------------------------------------------------------------------------------

// Element-wise Adam step over the index range [begin, end).  For each index the
// kernel computes, in this order:
//     g = weight_decay*params + params_grad
//     m = momentum1*m + (1-momentum1)*g
//     v = momentum2*v + (1-momentum2)*g*g
//     s = -alpha*m/(sqrt(v) + eps),   eps = 1e-8
// m and v are updated in place; s receives the step.  alpha is supplied by the
// caller (compute_adam_update folds the bias correction into it).
__global__ void _cuda_compute_adam_update(
    size_t begin,
    size_t end,
    float* s,
    float* m,
    float* v,
    const float alpha,
    const float weight_decay,
    const float momentum1,
    const float momentum2,
    const float* params,
    const float* params_grad
)
{
    const float eps = 1e-8;
    for (auto idx : grid_stride_range(begin, end))
    {
        const float grad = weight_decay*params[idx] + params_grad[idx];
        const float first_moment  = momentum1*m[idx] + (1-momentum1)*grad;
        const float second_moment = momentum2*v[idx] + (1-momentum2)*grad*grad;
        m[idx] = first_moment;
        v[idx] = second_moment;
        s[idx] = -alpha*first_moment/(std::sqrt(second_moment) + eps);
    }
}

// Host entry point for one Adam step at iteration t.  Checks that all tensors
// have the same size and that [begin, end) lies inside params, computes the
// bias-corrected step size, and launches _cuda_compute_adam_update.
void compute_adam_update (
    size_t begin,
    size_t end,
    tensor& s,
    tensor& m,
    tensor& v,
    const float t,
    const float learning_rate,
    const float weight_decay,
    const float momentum1,
    const float momentum2,
    const tensor& params,
    const tensor& params_grad
)
{
    DLIB_CASSERT(s.size() == m.size() &&
                 s.size() == v.size() &&
                 s.size() == params.size() &&
                 s.size() == params_grad.size());
    DLIB_CASSERT(begin <= end && end <= params.size());

    // Bias-correction denominators 1 - momentum^t for the two moment estimates.
    const float first_correction  = 1-std::pow(momentum1, t);
    const float second_correction = 1-std::pow(momentum2, t);
    const float alpha = learning_rate*std::sqrt(second_correction)/first_correction;

    launch_kernel(_cuda_compute_adam_update, max_jobs(end-begin),
        begin, end, s.device(), m.device(), v.device(), alpha,
        weight_decay, momentum1, momentum2, params.device(), params_grad.device());
}

// -----------------------------------------------------------------------------------

// d[i] = A[c]*s[i] + B[c] for every i in [0, n), where c = (i/bs)%ks picks the
// per-channel coefficient.  The host wrapper passes bs = nr*nc and ks = k.
__global__ void _cuda_affine_transform_conv(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs, size_t ks)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const auto channel = (idx/bs)%ks;
        d[idx] = A[channel]*s[idx] + B[channel];
    }
}

// Applies a per-channel affine transform: A and B must each hold one value per
// channel of src (shape 1 x k x 1 x 1); dest must match src's dimensions.
void affine_transform_conv(
    tensor& dest,
    const tensor& src,
    const tensor& A,
    const tensor& B
)
{
    DLIB_CASSERT(have_same_dimensions(dest, src));
    DLIB_CASSERT(have_same_dimensions(A, B));
    DLIB_CASSERT(A.num_samples() == 1 && A.nr() == 1 && A.nc() == 1 && A.k() == src.k());

    const auto plane_size = src.nr()*src.nc();
    launch_kernel(_cuda_affine_transform_conv, max_jobs(dest.size()),
        dest.device(), src.device(), src.size(), A.device(), B.device(),
        plane_size, src.k());
}

// -----------------------------------------------------------------------------------

// For each i in [0, n): out[i] = in[i] + in[i+n] + in[i+2n] + ... over all
// indices below total_n, i.e. the sum of `in` folded modulo n.
__global__ void _add_bias_gradient(float* out, const float* in, size_t n, size_t total_n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        float acc = in[idx];
        for (size_t j = idx+n; j < total_n; j += n)
            acc += in[j];
        out[idx] = acc;
    }
}

// Overwrites grad (a single-sample tensor) with gradient_input summed over its
// samples.
void assign_bias_gradient (
    tensor& grad,
    const tensor& gradient_input
)
{
    DLIB_CASSERT(
        grad.num_samples() == 1 &&
        gradient_input.k() == grad.k() &&
        gradient_input.nr() == grad.nr() &&
        gradient_input.nc() == grad.nc() &&
        gradient_input.size() > 0);

    launch_kernel(_add_bias_gradient, max_jobs(grad.size()),
        grad.device(), gradient_input.device(), grad.size(), gradient_input.size());
}

// ----------------------------------------------------------------------------------------

// Writes val into out[0..n).
__global__ void _set_tensor(float* out, size_t n, const float val)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] = val;
    }
}

// Sets every element of t to value.
void set_tensor (
    tensor& t,
    float value
)
{
    const auto count = t.size();
    launch_kernel(_set_tensor, max_jobs(count), t.device(), count, value);
}

// ----------------------------------------------------------------------------------------

// Multiplies out[0..n) by val in place.
__global__ void _scale_tensor(float* out, size_t n, const float val)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] *= val;
    }
}

// Multiplies every element of t by value.
void scale_tensor (
    tensor& t,
    float value
)
{
    const auto count = t.size();
    launch_kernel(_scale_tensor, max_jobs(count), t.device(), count, value);
}

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

// Binarizes d[0..n) in place: 1 where the value is strictly greater than
// thresh, 0 otherwise.
__global__ void _cuda_threshold(float* d, size_t n, float thresh)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const bool above = d[idx] > thresh;
        d[idx] = above ? 1 : 0;
    }
}

void threshold (
    tensor& data,
    float thresh
)
{
    launch_kernel(_cuda_threshold, max_jobs(data.size()), data.device(), data.size(), thresh);
}

// ------------------------------------------------------------------------------------

// Sums a[i]*b[i] over [0, n).  Each thread keeps a private partial sum.  The
// partial sums are then combined into *result with warp_reduce_atomic_add.
__global__ void _cuda_dot(const float* a, const float* b, size_t n, float* result)
{
    float partial = 0;
    for (auto idx : grid_stride_range(0, n))
    {
        partial += a[idx]*b[idx];
    }
    warp_reduce_atomic_add(*result, partial);
}

// Dot product of a and b delivered to element idx of result.  result is not
// cleared here, so the value is presumably accumulated into what is already
// stored there — NOTE(review): confirm against the dot() API contract.
void dot (
    const tensor& a,
    const tensor& b,
    tensor& result,
    size_t idx
)
{
    DLIB_CASSERT(a.size() == b.size());
    DLIB_CASSERT(idx < result.size());

    launch_kernel(_cuda_dot, max_jobs(a.size()),
        a.device(), b.device(), a.size(), result.device()+idx);
}

// ----------------------------------------------------------------------------------------

// PReLU forward: d[i] = s[i] for positive inputs, (*pp)*s[i] otherwise.
__global__ void _cuda_prelu(const float* s, float* d, size_t n, const float* pp)
{
    const float slope = *pp;
    for (auto idx : grid_stride_range(0, n))
    {
        d[idx] = (s[idx] > 0) ? s[idx] : slope*s[idx];
    }
}

void prelu (
    tensor& dest,
    const tensor& src,
    const tensor& param
)
{
    launch_kernel(_cuda_prelu, max_jobs(dest.size()),
        src.device(), dest.device(), src.size(), param.device());
}

// ----------------------------------------------------------------------------------------

// PReLU backward.  Adds to out: gi[i] for positive inputs, (*pp)*gi[i]
// otherwise.  Also accumulates the slope gradient (sum of gi[i]*s[i] over
// non-positive inputs) into *ppgrad via warp_reduce_atomic_add.
__global__ void _cuda_prelu_gradient(float* out, const float* s, const float* gi, size_t n, const float* pp, float* ppgrad)
{
    const float slope = *pp;
    float slope_grad = 0;
    for (auto idx : grid_stride_range(0, n))
    {
        if (!(s[idx] > 0))
        {
            out[idx] += slope*gi[idx];
            slope_grad += gi[idx]*s[idx];
        }
        else
        {
            out[idx] += gi[idx];
        }
    }
    warp_reduce_atomic_add(*ppgrad, slope_grad);
}

void prelu_gradient (
    tensor& grad,
    const tensor& src,
    const tensor& gradient_input,
    const tensor& param,
    tensor& params_grad
)
{
    // The kernel accumulates into params_grad, so it starts from zero.
    params_grad = 0;
    launch_kernel(_cuda_prelu_gradient, max_jobs(grad.size()),
        grad.device(), src.device(), gradient_input.device(), grad.size(),
        param.device(), params_grad.device());
}

// ----------------------------------------------------------------------------------------

// Leaky ReLU forward: d[i] = s[i] for positive inputs, alpha*s[i] otherwise.
__global__ void _cuda_leaky_relu(const float* s, float* d, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        d[idx] = (s[idx] > 0) ? s[idx] : alpha * s[idx];
    }
}

void leaky_relu(
    tensor& dest,
    const tensor& src,
    const float alpha
)
{
    launch_kernel(_cuda_leaky_relu, max_jobs(dest.size()),
        src.device(), dest.device(), src.size(), alpha);
}

// ----------------------------------------------------------------------------------------

// Leaky ReLU backward, assigning variant (used when out and gi are the same buffer).
__global__ void _cuda_leaky_relu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] = (s[idx] > 0) ? gi[idx] : alpha * gi[idx];
    }
}

// Leaky ReLU backward, accumulating variant.
__global__ void _cuda_leaky_relu_gradient(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] += (s[idx] > 0) ? gi[idx] : alpha * gi[idx];
    }
}

// When grad and gradient_input share storage the in-place kernel is used,
// which assigns instead of accumulating.
void leaky_relu_gradient (
    tensor& grad,
    const tensor& src,
    const tensor& gradient_input,
    const float alpha
)
{
    float* grad_data = grad.device();
    const float* gi_data = gradient_input.device();
    if (grad_data == gi_data)
    {
        launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
            grad_data, src.device(), gi_data, grad.size(), alpha);
        return;
    }
    launch_kernel(_cuda_leaky_relu_gradient, max_jobs(grad.size()),
        grad_data, src.device(), gi_data, grad.size(), alpha);
}

// ----------------------------------------------------------------------------------------

// Mish forward, computed as x - 2x/(2e + e^2 + 2) with e = exp(x).  This
// equals x*tanh(log(1 + e)).
__global__ void _cuda_mish(const float* s, float* d, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const float x = s[idx];
        const auto e = std::exp(x);
        const auto delta = 2*e + e*e + 2;
        d[idx] = x - 2*x/delta;
    }
}

void mish (
    tensor& dest,
    const tensor& src
)
{
    launch_kernel(_cuda_mish, max_jobs(dest.size()), src.device(), dest.device(), src.size());
}

// ----------------------------------------------------------------------------------------

// Derivative of mish at x.  For x >= 8 it returns exactly 1; for x <= -8 it
// returns exactly 0.
__device__ float mish_compute_gradient(float x)
{
    if (x >= 8)
        return 1.f;
    if (x <= -8)
        return 0.f;

    const auto e = std::exp(x);
    const auto delta = 2*e + e*e + 2;
    const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
    return e*omega/(delta*delta);
}

// Mish backward, assigning variant (out and gi are the same buffer).
__global__ void _cuda_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] = gi[idx]*mish_compute_gradient(s[idx]);
    }
}

// Mish backward, accumulating variant.
__global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] += gi[idx]*mish_compute_gradient(s[idx]);
    }
}

void mish_gradient (
    tensor& grad,
    const tensor& src,
    const tensor& gradient_input
)
{
    float* grad_data = grad.device();
    const float* gi_data = gradient_input.device();
    if (grad_data == gi_data)
    {
        launch_kernel(_cuda_mish_gradient_inplace, max_jobs(grad.size()),
            grad_data, src.device(), gi_data, grad.size());
        return;
    }
    launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()),
        grad_data, src.device(), gi_data, grad.size());
}

// ----------------------------------------------------------------------------------------

// Clipped ReLU forward: clamps s[i] to [0, alpha] (NaN passes through).
__global__ void _cuda_clipped_relu(const float* s, float* d, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const float x = s[idx];
        float y = x;
        if (x < 0)
            y = 0;
        else if (x > alpha)
            y = alpha;
        d[idx] = y;
    }
}

void clipped_relu (
    tensor& dest,
    const tensor &src,
    const float alpha
)
{
    launch_kernel(_cuda_clipped_relu, max_jobs(dest.size()),
        src.device(), dest.device(), src.size(), alpha);
}

// ----------------------------------------------------------------------------------------

// Clipped ReLU backward, assigning variant.  The gradient passes only where the
// forward output lies strictly inside (0, alpha).
__global__ void _cuda_clipped_relu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const bool inside = s[idx] > 0 && s[idx] < alpha;
        out[idx] = inside ? gi[idx] : 0.f;
    }
}

// Clipped ReLU backward, accumulating variant.
__global__ void _cuda_clipped_relu_gradient(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const bool inside = s[idx] > 0 && s[idx] < alpha;
        if (inside)
            out[idx] += gi[idx];
    }
}

void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
)
{
    float* grad_data = grad.device();
    const float* gi_data = gradient_input.device();
    if (grad_data == gi_data)
    {
        launch_kernel(_cuda_clipped_relu_gradient_inplace, max_jobs(grad.size()),
            grad_data, dest.device(), gi_data, grad.size(), alpha);
        return;
    }
    launch_kernel(_cuda_clipped_relu_gradient, max_jobs(grad.size()),
        grad_data, dest.device(), gi_data, grad.size(), alpha);
}

// ----------------------------------------------------------------------------------------

// ELU forward: x for positive inputs, alpha*(exp(x) - 1) otherwise.
__global__ void _cuda_elu(const float* s, float* d, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const float x = s[idx];
        d[idx] = (x > 0) ? x : alpha * (std::exp(x) - 1.0f);
    }
}

void elu (
    tensor& dest,
    const tensor &src,
    const float alpha
)
{
    launch_kernel(_cuda_elu, max_jobs(dest.size()),
        src.device(), dest.device(), src.size(), alpha);
}

// ----------------------------------------------------------------------------------------

// ELU backward, assigning variant.  s holds the forward output y.  For y <= 0
// the derivative alpha*exp(x) equals alpha + y.
__global__ void _cuda_elu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] = (s[idx] > 0) ? gi[idx] : (alpha + s[idx]) * gi[idx];
    }
}

// ELU backward, accumulating variant (same formula as above).
__global__ void _cuda_elu_gradient(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] += (s[idx] > 0) ? gi[idx] : (alpha + s[idx]) * gi[idx];
    }
}

void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
)
{
    float* grad_data = grad.device();
    const float* gi_data = gradient_input.device();
    if (grad_data == gi_data)
    {
        launch_kernel(_cuda_elu_gradient_inplace, max_jobs(grad.size()),
            grad_data, dest.device(), gi_data, grad.size(), alpha);
        return;
    }
    launch_kernel(_cuda_elu_gradient, max_jobs(grad.size()),
        grad_data, dest.device(), gi_data, grad.size(), alpha);
}

// ----------------------------------------------------------------------------------------

// GELU forward (exact form): x * Phi(x), where Phi is the standard normal CDF.
__global__ void _cuda_gelu(const float* s, float* d, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const float x = s[idx];
        d[idx] = x * normcdf(x);
    }
}

void gelu (
    tensor& dest,
    const tensor& src
)
{
    launch_kernel(_cuda_gelu, max_jobs(dest.size()), src.device(), dest.device(), src.size());
}

// ----------------------------------------------------------------------------------------

// d/dx [x*Phi(x)] = Phi(x) + x*phi(x), where phi(x) = exp(-x^2/2)/sqrt(2*pi).
__device__ float gelu_compute_gradient(float x)
{
    const float beta = 1.0f / CUDART_SQRT_2PI;
    const float cdf = normcdf(x);
    const float pdf = beta*std::exp(-0.5f*x*x);
    return cdf + x * pdf;
}

// GELU backward, assigning variant (out and gi are the same buffer).
__global__ void _cuda_gelu_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] = gi[idx]*gelu_compute_gradient(s[idx]);
    }
}

// GELU backward, accumulating variant.
__global__ void _cuda_gelu_gradient(float* out, const float* s, const float* gi, size_t n)
{
    for (auto idx : grid_stride_range(0, n))
    {
        out[idx] += gi[idx]*gelu_compute_gradient(s[idx]);
    }
}

void gelu_gradient (
    tensor& grad,
    const tensor& src,
    const tensor& gradient_input
)
{
    float* grad_data = grad.device();
    const float* gi_data = gradient_input.device();
    if (grad_data == gi_data)
    {
        launch_kernel(_cuda_gelu_gradient_inplace, max_jobs(grad.size()),
            grad_data, src.device(), gi_data, grad.size());
        return;
    }
    launch_kernel(_cuda_gelu_gradient, max_jobs(grad.size()),
        grad_data, src.device(), gi_data, grad.size());
}

// ----------------------------------------------------------------------------------------

// SmeLU forward:
//   x                          for x >= beta
//   0                          for x <= -beta
//   (x + beta)^2 / (4*beta)    in between.
__global__ void _cuda_smelu (const float* s, float* d, size_t n, const float beta)
{
    for (auto idx : grid_stride_range(0, n))
    {
        const float x = s[idx];
        d[idx] = (x >= beta)  ? x
               : (x <= -beta) ? 0.0f
               : (x + beta) * (x + beta) / (4 * beta);
    }
}
void smelu (
    tensor& dest,
    const tensor& src,
    const float beta
)
{
    launch_kernel(_cuda_smelu, max_jobs(dest.size()), src.device(), dest.device(), src.size(), beta);
}

// ----------------------------------------------------------------------------------------

// SmeLU backward, assigning variant (out and gi are the same buffer).
// The derivative is evaluated as sqrt(beta*s)/beta.  That equals
// (x+beta)/(2*beta) only when s holds the forward OUTPUT
// y = (x+beta)^2/(4*beta).  For raw inputs in (-beta, 0) it would take the
// square root of a negative number.
// NOTE(review): callers therefore presumably pass the smelu() output as `src`
// — confirm.
__global__ void _cuda_smelu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float beta)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] >= beta)
            out[i] = gi[i];
        else if (s[i] == 0)
            out[i] = 0;
        else
            out[i] = std::sqrt(beta * s[i]) / beta * gi[i];
    }
}

// SmeLU backward, accumulating variant (same formula and caveat as above).
__global__ void _cuda_smelu_gradient(float* out, const float* s, const float* gi, size_t n, const float beta)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] >= beta)
            out[i] += gi[i];
        else if (s[i] == 0)
            continue;
        else
            out[i] += std::sqrt(beta * s[i]) / beta * gi[i];
    }
}

void smelu_gradient (
    tensor& grad,
    const tensor& src,
    const tensor& gradient_input,
    const float beta
)
{
    float* out = grad.device();
    const float* gi = gradient_input.device();
    if (out == gi)
    {
        launch_kernel(_cuda_smelu_gradient_inplace, max_jobs(grad.size()),
            out, src.device(), gi, grad.size(), beta);
    }
    else
    {
        launch_kernel(_cuda_smelu_gradient, max_jobs(grad.size()),
            out, src.device(), gi, grad.size(), beta);
    }
}

// ----------------------------------------------------------------------------------------

// SiLU forward: x * sigmoid(x), computed as x / (1 + exp(-x)).
__global__ void _cuda_silu(const float* s, float* d, size_t n)
{
    for (auto i : grid_stride_range(0, n))
    {
        d[i] = s[i] / (1.0f + std::exp(-s[i]));
    }
}

void silu (
    tensor& dest,
    const tensor& src
)
{
    launch_kernel(_cuda_silu, max_jobs(dest.size()), src.device(), dest.device(), src.size());
}

// ----------------------------------------------------------------------------------------

// SiLU backward, assigning variant: gi * sig(x) * (1 + x*(1 - sig(x))).
__global__ void _cuda_silu_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
    for (auto i : grid_stride_range(0, n))
    {
        const auto sig_s = 1.0f / (1.0f + std::exp(-s[i]));
        out[i] = gi[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s)));
    }
}

// SiLU backward, accumulating variant.
__global__ void _cuda_silu_gradient(float* out, const float* s, const float* gi, size_t n)
{
    for (auto i : grid_stride_range(0, n))
    {
        const auto sig_s = 1.0f / (1.0f + std::exp(-s[i]));
        out[i] += gi[i] * (sig_s * (1.0f + s[i] * (1.0f - sig_s)));
    }
}

void silu_gradient (
    tensor& grad,
    const tensor& src,
    const tensor& gradient_input
)
{
    float* out = grad.device();
    const float* gi = gradient_input.device();
    if (out == gi)
        launch_kernel(_cuda_silu_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
    else
        launch_kernel(_cuda_silu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
}

// ----------------------------------------------------------------------------------------

// Bilinear resize of densely packed channel planes.  Destination pixel (r, c)
// samples the source at (r*y_scale, c*x_scale).  The bottom/right neighbours
// are clamped to the last source row/column.
__global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc,
                                      float* d, size_t schan_size, int snr, int snc, const float* s,
                                      const float x_scale, const float y_scale)
{
    for(auto i : grid_stride_range(0, dsize))
    {
        const int idx = i%dchan_size;
        const int channel = i/dchan_size;
        const int sidx = channel*schan_size;
        const int r = idx/dnc;
        const int c = idx%dnc;

        const float y = r*y_scale;
        const int top = static_cast<int>(::floorf(y));
        const int bottom = ::min(top+1, snr-1);
        const float tb_frac = y - top;

        const float x = c*x_scale;
        const int left = static_cast<int>(::floorf(x));
        const int right = ::min(left+1, snc-1);
        const float lr_frac = x - left;

        float tl = s[sidx+top*snc+left];
        float tr = s[sidx+top*snc+right];
        float bl = s[sidx+bottom*snc+left];
        float br = s[sidx+bottom*snc+right];

        float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
            tb_frac*((1-lr_frac)*bl + lr_frac*br);

        d[i] = temp;
    }
}

// Same as _cuda_resize_bilinear but with explicit row strides for both tensors
// and a channel stride for the destination.  schan_size is the source channel
// stride.
__global__ void _cuda_resize_bilinear_strided(size_t dsize, size_t dchan_size, size_t dnc,
                                              float* d, size_t schan_size, int snr, int snc, const float* s,
                                              const float x_scale, const float y_scale,
                                              size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided
)
{
    for(auto i : grid_stride_range(0, dsize))
    {
        const int idx = i%dchan_size;
        const int channel = i/dchan_size;
        const int sidx = channel*schan_size;
        const int r = idx/dnc;
        const int c = idx%dnc;
        const int didx = channel*dest_chan_size_strided + r*dest_row_stride+c;

        const float y = r*y_scale;
        const int top = static_cast<int>(::floorf(y));
        const int bottom = ::min(top+1, snr-1);
        const float tb_frac = y - top;

        const float x = c*x_scale;
        const int left = static_cast<int>(::floorf(x));
        const int right = ::min(left+1, snc-1);
        const float lr_frac = x - left;

        float tl = s[sidx+top*src_row_stride+left];
        float tr = s[sidx+top*src_row_stride+right];
        float bl = s[sidx+bottom*src_row_stride+left];
        float br = s[sidx+bottom*src_row_stride+right];

        float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
            tb_frac*((1-lr_frac)*bl + lr_frac*br);

        d[didx] = temp;
    }
}

// Resizes src into dest (same num_samples and k).  Returns early when either
// tensor is empty.  Uses the dense kernel when both layouts are packed and the
// strided kernel otherwise.
void resize_bilinear (
    tensor& dest,
    long long dest_row_stride,
    long long dest_channel_stride,
    const tensor& src,
    long long src_row_stride,
    long long src_channel_stride
)
{
    DLIB_CASSERT(is_same_object(dest, src)==false);
    DLIB_CASSERT(dest.num_samples() == src.num_samples());
    DLIB_CASSERT(dest.k() == src.k());

    if (dest.size() == 0 || src.size() == 0)
        return;

    // Explicit template argument: nc()/nr() arithmetic and the literal 1 have
    // different types, so std::max cannot deduce T on its own.
    const float x_scale = (src.nc()-1)/(float)std::max<long long>((dest.nc()-1),1);
    const float y_scale = (src.nr()-1)/(float)std::max<long long>((dest.nr()-1),1);

    if (dest.nc() == dest_row_stride && dest.nr()*dest.nc()==dest_channel_stride &&
        src.nc()  == src_row_stride  && src.nr()*src.nc()==src_channel_stride)
    {
        launch_kernel(_cuda_resize_bilinear,
                dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
                src.nr()*src.nc(), src.nr(), src.nc(), src.device(),
                x_scale, y_scale);
    }
    else
    {
        launch_kernel(_cuda_resize_bilinear_strided,
                dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
                src_channel_stride, src.nr(), src.nc(), src.device(),
                x_scale, y_scale, dest_row_stride, src_row_stride, dest_channel_stride);
    }
}

// ----------------------------------------------------------------------------------------

// Backward of _cuda_resize_bilinear.  Each incoming gradient d[i] is scattered
// with atomicAdd onto the four source pixels it was interpolated from, using
// the same bilinear weights.
__global__ void _cuda_resize_bilinear_gradient(size_t dsize, size_t dchan_size, size_t dnc,
                                               const float* d, size_t schan_size, int snr, int snc, float* s,
                                               const float x_scale, const float y_scale)
{
    for(auto i : grid_stride_range(0, dsize))
    {
        const float tmp = d[i];

        const int idx = i%dchan_size;
        const int channel = i/dchan_size;
        const int sidx = channel*schan_size;
        const int r = idx/dnc;
        const int c = idx%dnc;

        const float y = r*y_scale;
        const int top = static_cast<int>(::floorf(y));
        const int bottom = ::min(top+1, snr-1);
        const float tb_frac = y - top;

        const float x = c*x_scale;
        const int left = static_cast<int>(::floorf(x));
        const int right = ::min(left+1, snc-1);
        const float lr_frac = x - left;

        atomicAdd(s+sidx+top*snc+left,     tmp*(1-tb_frac)*(1-lr_frac));
        atomicAdd(s+sidx+top*snc+right,    tmp*(1-tb_frac)*(lr_frac));
        atomicAdd(s+sidx+bottom*snc+left,  tmp*(tb_frac)*(1-lr_frac));
        atomicAdd(s+sidx+bottom*snc+right, tmp*(tb_frac)*(lr_frac));
    }
}

// Strided variant of _cuda_resize_bilinear_gradient.
__global__ void _cuda_resize_bilinear_gradient_strided(size_t dsize, size_t dchan_size, size_t dnc,
                                                       const float* d, size_t schan_size, int snr, int snc, float* s,
                                                       const float x_scale, const float y_scale,
                                                       size_t dest_row_stride, size_t src_row_stride, size_t dest_chan_size_strided
)
{
    for(auto i : grid_stride_range(0, dsize))
    {
        const int idx = i%dchan_size;
        const int channel = i/dchan_size;
        const int didx = channel*dest_chan_size_strided;
        const int sidx = channel*schan_size;
        const int r = idx/dnc;
        const int c = idx%dnc;

        const float tmp = d[didx + r*dest_row_stride+c];

        const float y = r*y_scale;
        const int top = static_cast<int>(::floorf(y));
        const int bottom = ::min(top+1, snr-1);
        const float tb_frac = y - top;

        const float x = c*x_scale;
        const int left = static_cast<int>(::floorf(x));
        const int right = ::min(left+1, snc-1);
        const float lr_frac = x - left;

        atomicAdd(s+sidx+top*src_row_stride+left,     tmp*(1-tb_frac)*(1-lr_frac));
        atomicAdd(s+sidx+top*src_row_stride+right,    tmp*(1-tb_frac)*(lr_frac));
        atomicAdd(s+sidx+bottom*src_row_stride+left,  tmp*(tb_frac)*(1-lr_frac));
        atomicAdd(s+sidx+bottom*src_row_stride+right, tmp*(tb_frac)*(lr_frac));
    }
}

// Adds the gradient of resize_bilinear (w.r.t. its src) into grad.
void resize_bilinear_gradient (
    tensor& grad,
    long long grad_row_stride,
    long long grad_channel_stride,
    const tensor& gradient_input,
    long long gradient_input_row_stride,
    long long gradient_input_channel_stride
)
{
    DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
    DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples());
    DLIB_CASSERT(gradient_input.k() == grad.k());

    if (grad.size() == 0 || gradient_input.size() == 0)
        return;

    const float x_scale = (grad.nc()-1)/(float)std::max<long long>((gradient_input.nc()-1),1);
    const float y_scale = (grad.nr()-1)/(float)std::max<long long>((gradient_input.nr()-1),1);

    if (grad.nc() == grad_row_stride && grad.nr()*grad.nc()==grad_channel_stride &&
        gradient_input.nc() == gradient_input_row_stride && gradient_input.nr()*gradient_input.nc()==gradient_input_channel_stride)
    {
        launch_kernel(_cuda_resize_bilinear_gradient,
                gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
                grad.nr()*grad.nc(), grad.nr(), grad.nc(), grad.device(),
                x_scale, y_scale);
    }
    else
    {
        launch_kernel(_cuda_resize_bilinear_gradient_strided,
                gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
                grad_channel_stride, grad.nr(), grad.nc(), grad.device(),
                x_scale, y_scale, gradient_input_row_stride, grad_row_stride, gradient_input_channel_stride);
    }
}

// ----------------------------------------------------------------------------------------

// Space-to-depth: each destination element (n, out_k, out_r, out_c) reads the
// source element
//   (n, out_k % sk, out_r*row_stride + (out_k/sk)/col_stride,
//                   out_c*col_stride + (out_k/sk)%col_stride).
// It adds to d[i] when add_to is set and assigns otherwise.
__global__ void _cuda_reorg(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d,
                            size_t sk, size_t snr, int snc, const float* s,
                            const size_t row_stride, const size_t col_stride, const bool add_to)
{
    const auto out_plane_size = dnr * dnc;
    const auto out_sample_size = dk * out_plane_size;
    for (auto i : grid_stride_range(0, dsize))
    {
        const auto n = i / out_sample_size;
        const auto out_idx = i % out_sample_size;
        const auto out_k = out_idx / out_plane_size;
        const auto out_rc = out_idx % out_plane_size;
        const auto out_r = out_rc / dnc;
        const auto out_c = out_rc % dnc;

        const auto in_k = out_k % sk;
        const auto in_r = out_r * row_stride + (out_k / sk) / col_stride;
        const auto in_c = out_c * col_stride + (out_k / sk) % col_stride;

        const auto in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;
        if (add_to) d[i] += s[in_idx];
        else d[i] = s[in_idx];
    }
}

// Inverse mapping of _cuda_reorg: scatters each gradient_input element s[i]
// back to its position in grad (d).
__global__ void _cuda_reorg_gradient(size_t ssize, size_t dk, size_t dnr, size_t dnc, float* d,
                                     size_t sk, size_t snr, int snc, const float* s, const size_t row_stride,
                                     const size_t col_stride, const bool add_to
)
{
    for(auto i : grid_stride_range(0, ssize))
    {
        const auto n = i / (sk * snr * snc);
        const auto sample_idx = i % (sk * snr * snc);
        const auto in_k = (sample_idx / (snr * snc)) % sk;
        const auto in_r = (sample_idx / snc) % snr;
        const auto in_c = sample_idx % snc;

        const auto out_k = in_k % dk;
        const auto out_r = in_r * row_stride + (in_k / dk) / col_stride;
        const auto out_c = in_c * col_stride + (in_k / dk) % col_stride;
        const auto out_idx = ((n * dk + out_k) * dnr + out_r) * dnc + out_c;

        if (add_to) d[out_idx] += s[i];
        else d[out_idx] = s[i];
    }
}

void reorg(
    bool add_to,
    tensor& dest,
    const int row_stride,
    const int col_stride,
    const tensor& src
)
{
    DLIB_CASSERT(!is_same_object(dest, src), "Destination and source must be distinct objects.");
    DLIB_CASSERT(src.nr() % row_stride == 0, "The number of rows in src must be divisible by row_stride.");
    DLIB_CASSERT(src.nc() % col_stride == 0, "The number of columns in src must be divisible by col_stride.");
    DLIB_CASSERT(dest.num_samples() == src.num_samples(), "The number of samples must match.");
    DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride, "The number of channels must match.");
    DLIB_CASSERT(dest.nr() == src.nr() / row_stride, "The number of rows must match.");
    DLIB_CASSERT(dest.nc() == src.nc() / col_stride, "The number of columns must match.");

    launch_kernel(_cuda_reorg, dest.size(), dest.k(), dest.nr(), dest.nc(), dest.device(),
        src.k(), src.nr(), src.nc(), src.device(), row_stride, col_stride, add_to);
}

void reorg_gradient(
    bool add_to,
    tensor& grad,
    const int row_stride,
    const int col_stride,
    const tensor& gradient_input
)
{
    DLIB_CASSERT(!is_same_object(grad, gradient_input), "Grad and gradient_input must be distinct objects.");
    DLIB_CASSERT(grad.nr() % row_stride == 0, "The number of rows in grad must be divisible by row_stride.");
    DLIB_CASSERT(grad.nc() % col_stride == 0, "The number of columns in grad must be divisible by col_stride.");
    DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples(), "The number of samples in grad and gradient_input must match.");
    DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride, "The number of channels in grad must be gradient_input.k() divided by row_stride and col_stride.");
    DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride, "The number of rows in grad must be gradient_input.nr() multiplied by row_stride.");
    DLIB_CASSERT(grad.nc() == gradient_input.nc() * col_stride, "The number of columns in grad must be gradient_input.nc() multiplied by col_stride.");

    launch_kernel(_cuda_reorg_gradient, gradient_input.size(), grad.k(), grad.nr(), grad.nc(), grad.device(),
        gradient_input.k(), gradient_input.nr(), gradient_input.nc(), gradient_input.device(),
        row_stride, col_stride, add_to);
}

// ----------------------------------------------------------------------------------------

// Embedding lookup.  The token id stored at s[(n*dk + k)*dr + r] (truncated to
// unsigned long) selects row t_idx of e, copied across the dc output columns.
// Ids >= es (the number of embeddings) produce zeros.
__global__ void _cuda_embeddings(size_t dsize, size_t dk, size_t dr, size_t dc,
    float* d, const float* s, const float* e, size_t es
)
{
    for (auto i : grid_stride_range(0, dsize))
    {
        const auto n = i / (dk * dr * dc);
        const auto s_idx = i % (dk * dr * dc);
        const auto k = (s_idx / (dr * dc)) % dk;
        const auto r = (s_idx / dc) % dr;
        const auto c = s_idx % dc;

        const unsigned long t_idx = static_cast<unsigned long>(s[(n * dk + k) * dr + r]);

        if (t_idx < es)
            d[i] = e[t_idx * dc + c];
        else
            d[i] = 0.0f;
    }
}

// dest is not resized here; its k/nr/nc drive the indexing above.
void embeddings(
    resizable_tensor& dest,
    const tensor& src,
    const tensor& embs
)
{
    DLIB_CASSERT(
        src.nr() > 0 &&
        embs.num_samples() > 0 &&
        embs.k() > 0 &&
        embs.nr() == 1 &&
        embs.nc() == 1,
        "\nsrc.num_samples(): " << src.num_samples() <<
        "\nsrc.k(): " << src.k() <<
        "\nsrc.nr(): " << src.nr() <<
        "\nsrc.nc(): " << src.nc() <<
        "\nembs.num_samples(): " << embs.num_samples() <<
        "\nembs.k(): " << embs.k() <<
        "\nembs.nr(): " << embs.nr() <<
        "\nembs.nc(): " << embs.nc()
    );

    const long dk = dest.k();
    const long dr = dest.nr();
    const long dc = dest.nc();

    launch_kernel(_cuda_embeddings, dest.size(), dk, dr, dc,
        dest.device(), src.device(), embs.device(), embs.num_samples());
}

// Applies an SGD-style update directly to the embedding table g:
//     g[t] -= gi * lr * f_s
// Tokens whose frequency f[t] exceeds 1 are updated with atomicAdd; the others
// use a plain read-modify-write.
// NOTE(review): fmaxf(1.0f / f_t, 1.0f) is always >= 1, so
// fminf(0.15f, ...) always yields 0.15 whenever scaling is enabled.
// Confirm this is intended (and matches the CPU path).
__global__ void _cuda_embeddings_gradient(size_t ssize, size_t sk, size_t sr, size_t sc,
    const float* o, const float* gi, float* g, const float* f, float lr, bool sl, size_t es
)
{
    for (auto i : grid_stride_range(0, ssize))
    {
        const auto n = i / (sk * sr * sc);
        const auto s_idx = i % (sk * sr * sc);
        const auto k = (s_idx / (sr * sc)) % sk;
        const auto r = (s_idx / sc) % sr;
        const auto c = s_idx % sc;

        const unsigned long t_idx = static_cast<unsigned long>(o[(n * sk + k) * sr + r]);
        if (t_idx < es)
        {
            const float f_t = f[t_idx];
            float f_s = 1.0f;

            if (sl && f_t != 0.0f) f_s = fminf(0.15f, fmaxf(1.0f / f_t, 1.0f));
            if (f_t > 1) atomicAdd(&g[t_idx * sc + c], -gi[i] * lr * f_s);
            else g[t_idx * sc + c] -= gi[i] * lr * f_s;
        }
    }
}

void embeddings_gradient(
    const tensor& prev,
    const tensor& gradient_input,
    tensor& grads,
    const tensor& freqs,
    float learning_rate,
    bool scale
)
{
    DLIB_CASSERT(
        prev.nr() > 0 &&
        gradient_input.num_samples() == prev.num_samples() &&
        gradient_input.k() == prev.k() &&
        gradient_input.nr() == prev.nr() &&
        gradient_input.nc() == grads.k() &&
        grads.num_samples() > 0 &&
        grads.k() > 0 &&
        grads.nr() == 1 &&
        grads.nc() == 1,
        "\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
        "\ngradient_input.k(): " << gradient_input.k() <<
        "\ngradient_input.nr(): " << gradient_input.nr() <<
        "\ngradient_input.nc(): " << gradient_input.nc() <<
        "\nprev.num_samples(): " << prev.num_samples() <<
        "\nprev.k(): " << prev.k() <<
        "\nprev.nr(): " << prev.nr() <<
        "\nprev.nc(): " << prev.nc() <<
        "\ngrads.num_samples(): " << grads.num_samples() <<
        "\ngrads.k(): " << grads.k() <<
        "\ngrads.nr(): " << grads.nr() <<
        "\ngrads.nc(): " << grads.nc()
    );

    const long sk = gradient_input.k();
    const long sr = gradient_input.nr();
    const long sc = gradient_input.nc();

    launch_kernel(_cuda_embeddings_gradient, gradient_input.size(), sk, sr, sc,
        prev.device(), gradient_input.device(), grads.device(), freqs.device(),
        learning_rate, scale, grads.num_samples());
}

// ----------------------------------------------------------------------------------------

// Layer normalization over all k*num values of each sample n.  The kernel runs
// in three phases:
//   1. m[n] and v[n] accumulate the mean and mean-of-squares via atomics.
//   2. v[n] is turned into 1/sqrt(var + eps).
//   3. Normalize, then apply per-channel gamma/beta (channel = i / num).
// NOTE(review): __syncthreads() only orders threads inside one thread block,
// yet phase 2/3 read values that every thread working on sample n added to.
// This is only safe if launch_kernel/max_jobs keep each sample's work in a
// single block — confirm.
__global__ void _cuda_layer_normalize(
    float* out,
    const float* s,
    float* m,
    float* v,
    const float* g,
    const float* b,
    float eps,
    size_t ns,
    size_t k,
    size_t num
)
{
    // compute means and sum of squares
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = s + n * k * num;
        float means = 0;
        float invstds = 0;
        for (auto i : grid_stride_range(0, k * num))
        {
            means += ps[i];
            invstds += ps[i] * ps[i];
        }
        warp_reduce_atomic_add(m[n], means / (k * num));
        warp_reduce_atomic_add(v[n], invstds / (k * num));
    }
    __syncthreads();

    // compute variances
    for (auto n : grid_stride_range_y(0, ns))
    {
        for (auto i : grid_stride_range(0, 1))
        {
            v[n] = 1.0f / std::sqrt(v[n] - m[n] * m[n] + eps);
        }
    }
    __syncthreads();

    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = s + n * k * num;
        const auto pout = out + n * k * num;
        for (auto i : grid_stride_range(0, k * num))
        {
            pout[i] = (ps[i] - m[n]) * v[n];
            pout[i] = pout[i] * g[i / num] + b[i / num];
        }
    }
}

// Sizes dest like src and means/invstds to one value per sample, zeroes the
// accumulators, and launches _cuda_layer_normalize.
void layer_normalize (
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& means,
    resizable_tensor& invstds,
    const tensor& src,
    const tensor& gamma,
    const tensor& beta
)
{
    const long num = src.nr() * src.nc();
    DLIB_CASSERT(
        have_same_dimensions(gamma, beta) &&
        gamma.k() == src.k() &&
        gamma.nr() == 1 &&
        gamma.nc() == 1 &&
        eps > 0,
        "\nsrc.k(): " << src.k() <<
        "\ngamma.k(): " << gamma.k() <<
        "\ngamma.nr(): " << gamma.nr() <<
        "\ngamma.nc(): " << gamma.nc() <<
        "\nbeta.k(): " << beta.k() <<
        "\nbeta.nr(): " << beta.nr() <<
        "\nbeta.nc(): " << beta.nc() <<
        "\neps: " << eps
    );

    dest.copy_size(src);
    means.set_size(src.num_samples());
    invstds.set_size(src.num_samples());
    means = 0;
    invstds = 0;
    launch_kernel(_cuda_layer_normalize, max_jobs(src.k() * num, src.num_samples()), dest.device(), src.device(),
                  means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), src.k(), num);
}

// ----------------------------------------------------------------------------------------

// Layer normalization backward.
//   Phase 1 (per sample n, channel k):
//       bg[k] += gi
//       gg[k] += gi*x_hat
//       dv[n] += gi*gamma[k]*(x - m)*(-0.5*invstd^3)
//   Phase 2: dm[n] accumulates the mean gradient.
//   Phase 3: adds the input gradient into out.
// The same cross-block __syncthreads() caveat as _cuda_layer_normalize applies.
__global__ void _cuda_layer_normalize_gradient(
    float* out,
    float* gg,
    float* bg,
    const float* s,
    const float* gi,
    const float* m,
    const float* v,
    const float* g,
    float* dm,
    float* dv,
    float eps,
    size_t ns,
    size_t ks,
    size_t num)
{
    for (auto nk : grid_stride_range_y(0, ns * ks))
    {
        const auto n = nk / ks;
        const auto k = nk % ks;
        const auto ps = s + (n * ks + k) * num;
        const auto pgi = gi + (n * ks + k) * num;
        const float invstd_pow = -0.5 * std::pow(v[n], 3.0f);
        float temp_bg = 0;
        float temp_gg = 0;
        float temp_dv = 0;
        for (auto i : grid_stride_range(0, num))
        {
            const float x_hat = (ps[i] - m[n]) * v[n];
            // i only spans one channel plane here (i < num), so the channel's
            // gamma is g[k].  The previous g[i / num] always read g[0].
            const float dx = pgi[i] * g[k];
            temp_bg += pgi[i];
            temp_gg += pgi[i] * x_hat;
            temp_dv += dx * (ps[i] - m[n]) * invstd_pow;
        }
        warp_reduce_atomic_add(bg[k], temp_bg);
        warp_reduce_atomic_add(gg[k], temp_gg);
        warp_reduce_atomic_add(dv[n], temp_dv);
    }
    __syncthreads();

    const float invnum = 1.0f / (ks * num);
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = s + n * ks * num;
        const auto pgi = gi + n * ks * num;
        float temp_dm = 0;
        for (auto i : grid_stride_range(0, ks * num))
        {
            // Here i spans the whole sample, so i / num is the channel index.
            const float dx = pgi[i] * g[i / num];
            temp_dm += -dx * v[n] + dv[n] * -2 * (ps[i] - m[n]) * invnum;
        }
        warp_reduce_atomic_add(dm[n], temp_dm);
    }
    __syncthreads();

    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = s + n * ks * num;
        const auto pgi = gi + n * ks * num;
        const auto pout = out + n * ks * num;
        for (auto i : grid_stride_range(0, ks * num))
        {
            const float dx = pgi[i] * g[i / num];
            pout[i] += dx * v[n] + dv[n] * 2 * (ps[i] - m[n]) * invnum + dm[n] * invnum;
        }
    }
}

// Validates shapes, zeroes the gamma/beta gradients and the dmeans/dvars
// scratch tensors (sized like means/invstds), and launches the backward kernel.
// src_grad is accumulated into, not cleared.
void layer_normalize_gradient (
    const double eps,
    const tensor& gradient_input,
    const tensor& means,
    const tensor& invstds,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    tensor& beta_grad,
    resizable_tensor& dmeans,
    resizable_tensor& dvars
)
{
    const long num = src.nr() * src.nc();
    DLIB_CASSERT(src.num_samples() == means.size());
    DLIB_CASSERT(src.num_samples() == invstds.size());
    DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
    DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
    DLIB_CASSERT(gamma.k() == src.k());
    DLIB_CASSERT(gamma.nr() == 1);
    DLIB_CASSERT(gamma.nc() == 1);
    DLIB_CASSERT(have_same_dimensions(gradient_input, src));
    DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
    DLIB_CASSERT(eps > 0);

    beta_grad = 0;
    gamma_grad = 0;
    dvars.copy_size(invstds);
    dmeans.copy_size(means);
    dvars = 0;
    dmeans = 0;
    launch_kernel(_cuda_layer_normalize_gradient, max_jobs(src.k() * num, src.num_samples()),
                  src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(),
                  gradient_input.device(), means.device(), invstds.device(), gamma.device(),
                  dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
}

// ----------------------------------------------------------------------------------------

// RMS normalization per sample.
//   1. scale[n] = mean of squares over ks*num values.
//   2. scale[n] = 1/sqrt(scale[n] + eps).
//   3. dest = src * scale[n] * gamma[channel].
// The same cross-block __syncthreads() caveat as _cuda_layer_normalize applies.
__global__ void _cuda_rms_normalize(
    float* dest,
    float* scale,
    const float* src,
    const float* gamma,
    float eps,
    size_t ns,
    size_t ks,
    size_t num
)
{
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = src + n * ks * num;
        float sum_squares = 0.0f;
        for (auto i : grid_stride_range(0, ks * num))
        {
            sum_squares += ps[i] * ps[i];
        }
        warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
    }
    __syncthreads();

    for (auto n : grid_stride_range_y(0, ns))
    {
        for (auto i : grid_stride_range(0, 1))
        {
            scale[n] = 1.0f / std::sqrt(scale[n] + eps);
        }
    }
    __syncthreads();

    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = src + n * ks * num;
        const auto pd = dest + n * ks * num;
        for (auto i : grid_stride_range(0, ks * num))
        {
            pd[i] = ps[i] * scale[n] * gamma[i / num];
        }
    }
}

void rms_normalize(
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& scale,
    const tensor& src,
    const tensor& gamma
)
{
    DLIB_CASSERT(
        gamma.k() == src.k() &&
        gamma.nr() == 1 &&
        gamma.nc() == 1 &&
        eps > 0,
        "\nsrc.k(): " << src.k() <<
        "\ngamma.k(): " << gamma.k() <<
        "\ngamma.nr(): " << gamma.nr() <<
        "\ngamma.nc(): " << gamma.nc() <<
        "\neps: " << eps
    );

    const long ns = src.num_samples();
    const long ks = src.k();
    const long num = src.nr() * src.nc();

    dest.copy_size(src);
    scale.set_size(ns);
    scale = 0;

    launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
        dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_rms_normalize_gradient(
    float* src_grad,
    float* gamma_grad,
    float* dscale,
    const float* src,
    const float* gradient_input,
    const float* scale,
    const float* gamma,
    size_t ns,
    size_t ks,
    size_t num
)
{
    for (auto nk : grid_stride_range_y(0, ns * ks))
    {
        const auto n = nk / ks;
        const auto k = nk % ks;
        const auto ps = src + (n * ks + k) * num;
        const auto pgi = gradient_input + (n * ks + k) * num;
        const float scale_pow = -0.5f * std::pow(scale[n], 3.0f);
        float temp_gg = 0.0f;
        float temp_ds = 0.0f;
        for (auto i : grid_stride_range(0, num))
        {
            const float x_hat = ps[i] * scale[n];
            // NOTE(review): i < num here, so gamma[i / num] is always gamma[0];
            // gamma[k] looks intended — confirm and fix.
            const float dx = pgi[i] * gamma[i / num];
            temp_gg += pgi[i] * x_hat;
            temp_ds += dx * ps[i] * scale_pow;
        }
        warp_reduce_atomic_add(gamma_grad[k], temp_gg);
        warp_reduce_atomic_add(dscale[n], temp_ds);
    }
    __syncthreads();

    const float invnum = 1.0f / (ks * num);
    for (auto n : grid_stride_range_y(0, ns))
    {
        const auto ps = src + n * ks * num;
        const auto pgi = gradient_input + n * ks * num;
        const auto psg = src_grad + n * ks * num;
        for (auto i :
grid_stride_range(0, ks * num)) { const float dx = pgi[i] * gamma[i / num]; psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum; } } } void rms_normalize_gradient( const tensor& gradient_input, const tensor& scale, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, resizable_tensor& dscale ) { DLIB_CASSERT(src.num_samples() == scale.size()); DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad)); DLIB_CASSERT(gamma.k() == src.k()); DLIB_CASSERT(gamma.nr() == 1); DLIB_CASSERT(gamma.nc() == 1); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); const long ns = src.num_samples(); const long ks = src.k(); const long num = src.nr() * src.nc(); gamma_grad = 0; dscale.copy_size(scale); dscale = 0; // Lancement du kernel CUDA launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns), src_grad.device(), gamma_grad.device(), dscale.device(), src.device(), gradient_input.device(), scale.device(), gamma.device(), ns, ks, num); } // ---------------------------------------------------------------------------------------- __global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size) { for(auto i : grid_stride_range(0, size)) { size_t blk = i/block_size; size_t j = i%block_size; dest[blk*dest_stride + j] += src[blk*src_stride + j]; } } __global__ void _cuda_copy_tensor (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size) { for(auto i : grid_stride_range(0, size)) { size_t blk = i/block_size; size_t j = i%block_size; dest[blk*dest_stride + j] = src[blk*src_stride + j]; } } void copy_tensor( bool add_to, tensor& dest, size_t dest_k_offset, const tensor& src, size_t src_k_offset, size_t count_k ) { const size_t dest_sample_size = static_cast(dest.nc() * dest.nr() * dest.k()); const size_t src_sample_size = static_cast(src.nc() * src.nr() * 
src.k()); const size_t block_size = count_k * dest.nc() * dest.nr(); DLIB_CASSERT(dest.num_samples() == src.num_samples() && dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size"); DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor"); DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor"); float* dest_p = dest.device() + dest_k_offset * dest.nc() * dest.nr(); const float* src_p = src.device() + src_k_offset * src.nc() * src.nr();; if (add_to) { launch_kernel(_cuda_copy_tensor_add_to, max_jobs(dest.size()), dest_p, block_size*dest.num_samples(), src_p, dest_sample_size, src_sample_size, block_size); } else { launch_kernel(_cuda_copy_tensor, max_jobs(dest.size()), dest_p, block_size*dest.num_samples(), src_p, dest_sample_size, src_sample_size, block_size); } } __global__ void _cuda_copy_strided_tensor_add_to (float* dest, const float* src, size_t ns, size_t nk, size_t nr, size_t nc, size_t dk, size_t dr, size_t dc, size_t sk, size_t sr, size_t sc) { for(auto i : grid_stride_range(0, ns*nk*nr*nc)) { size_t n,k,r,c; unpack_idx(i, nk,nr,nc, n,k,r,c); dest[pack_idx(dk,dr,dc, n,k,r,c)] += src[pack_idx(sk,sr,sc, n,k,r,c)]; } } __global__ void _cuda_copy_strided_tensor (float* dest, const float* src, size_t ns, size_t nk, size_t nr, size_t nc, size_t dk, size_t dr, size_t dc, size_t sk, size_t sr, size_t sc) { for(auto i : grid_stride_range(0, ns*nk*nr*nc)) { size_t n,k,r,c; unpack_idx(i, nk,nr,nc, n,k,r,c); dest[pack_idx(dk,dr,dc, n,k,r,c)] = src[pack_idx(sk,sr,sc, n,k,r,c)]; } } void copy_tensor( bool add_to, tensor& dest, size_t dk, size_t dnr, size_t dnc, const tensor& src, size_t sk, size_t snr, size_t snc, size_t k, size_t nr, size_t nc ) { DLIB_CASSERT(dest.num_samples() == src.num_samples(), "All sources should fit into dest tensor size"); DLIB_CASSERT(dest.k() - dk >= k && dest.nr() - dnr >= nr && dest.nc() - dnc >= nc, "Not enough space in dest tensor"); 
DLIB_CASSERT(src.k() - sk >= k && src.nr() - snr >= nr && src.nc() - snc >= nc, "Not enough space in src tensor"); float* dest_p = dest.device() + dk * static_cast(dest.nc() * dest.nr()) \ + dnr * static_cast(dest.nc()) \ + dnc; const float* src_p = src.device() + sk * static_cast(src.nc() * src.nr()) \ + snr * static_cast(src.nc()) \ + snc; if (add_to) { launch_kernel(_cuda_copy_strided_tensor_add_to, max_jobs(dest.size()), dest_p, src_p, dest.num_samples(), k, nr, nc, dest.k(), dest.nr(), dest.nc(), src.k(), src.nr(), src.nc()); } else { launch_kernel(_cuda_copy_strided_tensor, max_jobs(dest.size()), dest_p, src_p, dest.num_samples(), k, nr, nc, dest.k(), dest.nr(), dest.nc(), src.k(), src.nr(), src.nc()); } } // ---------------------------------------------------------------------------------------- __global__ void _cuda_transpose(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d, size_t sk, size_t snr, int snc, const float* s, const bool add_to) { const auto plane_size = dnr * dnc; const auto sample_size = dk * plane_size; for (auto i : grid_stride_range(0, dsize)) { const auto n = i / sample_size; const auto idx = i % plane_size; const auto in_k = (i / plane_size) % dk; const auto in_r = idx % dnc; const auto in_c = idx / dnc; const auto in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c; if (add_to) d[i] += s[in_idx]; else d[i] = s[in_idx]; } } void transpose( bool add_to, tensor& dest, const tensor& src ) { DLIB_CASSERT(is_same_object(dest, src) == false); DLIB_CASSERT(dest.num_samples() == src.num_samples() && dest.k() == src.k() && dest.nr() == src.nc() && dest.nc() == src.nr(), "Incompatible tensor dimensions."); launch_kernel(_cuda_transpose, max_jobs(dest.size()), dest.size(), dest.k(), dest.nr(), dest.nc(), dest.device(), src.k(), src.nr(), src.nc(), src.device(), add_to); } // ---------------------------------------------------------------------------------------- // CUDA Kernels for ACT operations __global__ void 
_cuda_compute_act_halt_probabilities( float* halt_probs, float* logits, const float* input_data, const float* W_halt, float b_halt, size_t batch_size, size_t seq_len, size_t d_model, size_t num_channels, size_t feature_dim ) { const long total_positions = batch_size * seq_len; for (auto pos : grid_stride_range_y(0, total_positions)) for (auto i : grid_stride_range(0, 1)) logits[pos] = b_halt; __syncthreads(); for (auto pos : grid_stride_range_y(0, total_positions)) { const long n = pos / seq_len; const long s = pos % seq_len; float temp = 0; for (auto feat_idx : grid_stride_range(0, feature_dim)) { const long c = feat_idx / d_model; const long d = feat_idx % d_model; const long in_idx = ((n * num_channels + c) * seq_len + s) * d_model + d; temp += input_data[in_idx] * W_halt[feat_idx]; } warp_reduce_atomic_add(logits[pos], temp); } __syncthreads(); for (auto pos : grid_stride_range(0, total_positions)) { halt_probs[pos] = 1.0f / (1.0f + expf(-logits[pos])); } } void compute_act_halt_probabilities( resizable_tensor& halt_probs, resizable_tensor& logits, const tensor& input_data, const tensor& halt_params, long batch_size, long seq_len, long feature_dim ) { const long total_positions = batch_size * seq_len; const long d_model = feature_dim / input_data.k(); const long num_channels = input_data.k(); halt_probs.set_size(total_positions, 1, 1, 1); logits.set_size(total_positions, 1, 1, 1); launch_kernel(_cuda_compute_act_halt_probabilities, max_jobs(feature_dim, total_positions), halt_probs.device(), logits.device(), input_data.device(), halt_params.device(), halt_params.host()[feature_dim], batch_size, seq_len, d_model, num_channels, feature_dim); } __global__ void _cuda_update_act_state( float* output, const float* input_data, const float* halt_probs, float* cumulative_halting, float* remainders, float* n_steps, float* effective_weights, size_t batch_size, size_t seq_len, size_t d_model, size_t num_channels, float halt_threshold, long current_step ) { for (auto pos : 
grid_stride_range(0, batch_size * seq_len)) { if (cumulative_halting[pos] < halt_threshold) { const size_t n = pos / seq_len; const size_t s = pos % seq_len; float p = halt_probs[pos]; float r = remainders[pos]; float effective = fminf(p * r, halt_threshold - cumulative_halting[pos]); cumulative_halting[pos] += effective; remainders[pos] -= effective; n_steps[pos] = static_cast(current_step + 1); effective_weights[pos] += effective; for (size_t c = 0; c < num_channels; ++c) { for (size_t d = 0; d < d_model; ++d) { const size_t idx = ((n * num_channels + c) * seq_len + s) * d_model + d; output[idx] += effective * input_data[idx]; } } } } } void update_act_state( resizable_tensor& output, const tensor& input_data, const tensor& halt_probs, resizable_tensor& cumulative_halting, resizable_tensor& remainders, resizable_tensor& n_steps, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels, float halt_threshold, long current_step ) { const long total_positions = batch_size * seq_len; launch_kernel(_cuda_update_act_state, max_jobs(total_positions), output.device(), input_data.device(), halt_probs.device(), cumulative_halting.device(), remainders.device(), n_steps.device(), effective_weights.device(), batch_size, seq_len, d_model, num_channels, halt_threshold, current_step); } __global__ void _cuda_finalize_act_output( float* output, const float* input_data, const float* remainders, float* effective_weights, size_t batch_size, size_t seq_len, size_t d_model, size_t num_channels ) { for (auto pos : grid_stride_range(0, batch_size * seq_len)) { float r = remainders[pos]; if (r > 1e-6f) { const size_t n = pos / seq_len; const size_t s = pos % seq_len; effective_weights[pos] += r; for (size_t c = 0; c < num_channels; ++c) { for (size_t d = 0; d < d_model; ++d) { const size_t idx = ((n * num_channels + c) * seq_len + s) * d_model + d; output[idx] += r * input_data[idx]; } } } } } void finalize_act_output( resizable_tensor& output, 
const tensor& input_data, const tensor& remainders, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels ) { const long total_positions = batch_size * seq_len; launch_kernel(_cuda_finalize_act_output, max_jobs(total_positions), output.device(), input_data.device(), remainders.device(), effective_weights.device(), batch_size, seq_len, d_model, num_channels); } __global__ void _cuda_apply_act_depth_scaling( float* gradients, const float* n_steps, size_t batch_size, size_t seq_len, size_t d_model, size_t num_channels, float max_steps, float scale_factor ) { const long total_positions = batch_size * seq_len; const long feature_dim = num_channels * d_model; for (auto pos : grid_stride_range_y(0, total_positions)) { const long n = pos / seq_len; const long s = pos % seq_len; const float scale = 1.0f + scale_factor * (n_steps[pos] / max_steps); for (auto feat_idx : grid_stride_range(0, feature_dim)) { const long c = feat_idx / d_model; const long d = feat_idx % d_model; const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d; gradients[idx] *= scale; } } } void apply_act_depth_scaling( tensor& gradients, const tensor& n_steps, long batch_size, long seq_len, long d_model, long num_channels, float max_steps, float scale_factor ) { const long total_positions = batch_size * seq_len; const long feature_dim = num_channels * d_model; launch_kernel(_cuda_apply_act_depth_scaling, max_jobs(feature_dim, total_positions), gradients.device(), n_steps.device(), batch_size, seq_len, d_model, num_channels, max_steps, scale_factor); } // ---------------------------------------------------------------------------------------- __device__ float cuda_log1pexp(float x) { if (x <= -18) return std::exp(x); else if (-18 < x && x <= 9) return std::log1pf(std::exp(x)); else if (9 < x && x <= 16) return x + expf(-x); else return x; } __global__ void _cuda_compute_loss_binary_log_per_pixel(float* loss_out, float* g, const float* truth, 
const float* out_data, size_t n, const float scale) { float loss = 0; for(auto i : grid_stride_range(0, n)) { const float y = truth[i]; if (y > 0.f) { const float temp = cuda_log1pexp(-out_data[i]); loss += y*temp; g[i] = y*scale*(g[i]-1); } else if (y < 0.f) { const float temp = -(-out_data[i]-cuda_log1pexp(-out_data[i])); loss += -y*temp; g[i] = -y*scale*g[i]; } else { g[i] = 0.f; } } warp_reduce_atomic_add(*loss_out, loss); } // ---------------------------------------------------------------------------------------- __device__ float cuda_safe_log(float x, float epsilon = 1e-10) { // Prevent trying to calculate the logarithm of a very small number (let alone zero) if (x >= epsilon) return ::log(x); else return ::log(epsilon); } __global__ void _cuda_compute_loss_multiclass_log_per_pixel(float* loss_out, float* g, const uint16_t* truth, size_t n, size_t plane_size, size_t sample_size, size_t nk, uint16_t label_to_ignore, const float scale) { float loss = 0; for(auto i : grid_stride_range(0, n)) { const size_t k = (i/plane_size)%nk; const size_t idx = (i%plane_size) + plane_size*(i/sample_size); const size_t y = truth[idx]; if (k == y) { loss -= cuda_safe_log(g[i]); g[i] = scale*(g[i] - 1); } else if (y == label_to_ignore) { g[i] = 0.f; } else { g[i] = scale*g[i]; } } warp_reduce_atomic_add(*loss_out, loss); } __global__ void _cuda_compute_loss_multiclass_log_per_pixel_weighted(float* loss_out, float* g, const uint16_t* truth, size_t n, size_t plane_size, size_t sample_size, size_t nk, const float* weights, const float scale) { float loss = 0; for(auto i : grid_stride_range(0, n)) { const size_t k = (i/plane_size)%nk; const size_t idx = (i%plane_size) + plane_size*(i/sample_size); const size_t y = truth[idx]; const float weight = weights[idx]; if (k == y) { loss -= weight*cuda_safe_log(g[i]); g[i] = weight*scale*(g[i] - 1); } else { g[i] = weight*scale*g[i]; } } warp_reduce_atomic_add(*loss_out, loss); } // 
----------------------------------------------------------------------------------------

        // Device kernel used by compute_loss_mean_squared_per_channel_and_pixel::do_work().
        // For each of the n flat elements:
        //   temp  = truth[i] - out_data[i]
        //   loss += temp*temp     (unscaled; the host multiplies the total by scale afterwards)
        //   g[i]  = -temp*scale   (note: this equals scale/2 times d(temp^2)/d(out_data[i]))
        // g is overwritten, not accumulated into. Each thread's partial sum is added into
        // *loss_out with warp_reduce_atomic_add; *loss_out is never reset here, so the
        // caller must zero it before launch (do_work does so with cudaMemset).
        __global__ void _cuda_compute_loss_mean_squared_per_channel_and_pixel(float* loss_out, float* g, const float* truth, const float* out_data, size_t n, const float scale)
        {
            float loss = 0;
            for (auto i : grid_stride_range(0, n))
            {
                const float y = truth[i];
                const float temp = y - out_data[i];
                loss += temp * temp;
                g[i] = -temp * scale;
            }

            warp_reduce_atomic_add(*loss_out, loss);
        }

    // ----------------------------------------------------------------------------------------

        // Host-side driver for the binary log per-pixel loss.
        //  1. Zeroes the first float of loss_work_buffer, which the kernel accumulates into.
        //  2. Calls sigmoid(gradient, subnetwork_output). This presumably writes the sigmoid of
        //     the network output into gradient (TODO confirm argument order). The kernel then
        //     reads g[i] as that probability and overwrites it with the gradient in place.
        //  3. Launches _cuda_compute_loss_binary_log_per_pixel over every element of gradient.
        //     Per that kernel, a positive truth value weights the positive-class term, a
        //     negative value weights the negative-class term, and 0 zeroes the gradient and
        //     adds no loss.
        //  4. Copies the accumulated (unscaled) sum into floss and reports loss = scale*floss.
        void compute_loss_binary_log_per_pixel::
        do_work(
            cuda_data_ptr loss_work_buffer,
            cuda_data_ptr truth_buffer,
            const tensor& subnetwork_output,
            tensor& gradient,
            double& loss
        )
        {
            CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));
            sigmoid(gradient, subnetwork_output);
            // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
            const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.nr() * subnetwork_output.nc());

            launch_kernel(_cuda_compute_loss_binary_log_per_pixel, max_jobs(gradient.size()),
                loss_work_buffer.data(), gradient.device(), truth_buffer.data(), subnetwork_output.device(), gradient.size(), scale);

            float floss;
            dlib::cuda::memcpy(&floss, loss_work_buffer);
            loss = scale*floss;
        }

        // Host-side driver for the multiclass log per-pixel loss.
        // gradient first receives softmax(subnetwork_output), presumably across the k channels
        // (TODO confirm softmax's axis). The kernel then converts it to the gradient in place.
        //
        // The kernel receives plane_size = nr*nc, sample_size = nr*nc*k and nk = k. It uses
        // them to map each flat gradient index to its channel and to the matching per-pixel
        // truth index.
        //
        // Elements whose truth label equals label_to_ignore get a zero gradient.
        // label_to_ignore is the largest label value, presumably
        // numeric_limits<uint16_t>::max().
        void compute_loss_multiclass_log_per_pixel::
        do_work(
            cuda_data_ptr loss_work_buffer,
            cuda_data_ptr truth_buffer,
            const tensor& subnetwork_output,
            tensor& gradient,
            double& loss
        )
        {
            CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));
            softmax(gradient, subnetwork_output);
            static const uint16_t label_to_ignore = std::numeric_limits::max();

            // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
            const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.nr() * subnetwork_output.nc());

            launch_kernel(_cuda_compute_loss_multiclass_log_per_pixel, max_jobs(gradient.size()),
                loss_work_buffer.data(), gradient.device(), truth_buffer.data(), gradient.size(), gradient.nr()*gradient.nc(), gradient.nr()*gradient.nc()*gradient.k(), gradient.k(), label_to_ignore, scale);

            float floss;
            dlib::cuda::memcpy(&floss, loss_work_buffer);
            loss = scale*floss;
        }

        // Weighted variant of the multiclass log per-pixel loss.
        // Same flow as the unweighted driver, except for two differences:
        //  - There is no ignore label.
        //  - Each element's loss term and gradient are multiplied by weights[idx], where idx
        //    is the same per-pixel index used to read the truth label.
        void compute_loss_multiclass_log_per_pixel_weighted::
        do_work(
            cuda_data_ptr loss_work_buffer,
            cuda_data_ptr truth_buffer,
            cuda_data_ptr weights_buffer,
            const tensor& subnetwork_output,
            tensor& gradient,
            double& loss
        )
        {
            CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));
            softmax(gradient, subnetwork_output);

            // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
            const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.nr() * subnetwork_output.nc());

            launch_kernel(_cuda_compute_loss_multiclass_log_per_pixel_weighted, max_jobs(gradient.size()),
                loss_work_buffer.data(), gradient.device(), truth_buffer.data(), gradient.size(), gradient.nr()*gradient.nc(), gradient.nr()*gradient.nc()*gradient.k(), gradient.k(), weights_buffer.data(), scale);

            float floss;
            dlib::cuda::memcpy(&floss, loss_work_buffer);
            loss = scale*floss;
        }

        // Host-side driver for the mean squared per-channel-and-pixel loss.
        // No activation is applied; the kernel writes gradient directly from
        // truth - subnetwork_output. Unlike the per-pixel losses above, scale also divides by
        // k, so the reported loss averages over every channel as well.
        void compute_loss_mean_squared_per_channel_and_pixel::
        do_work(
            cuda_data_ptr loss_work_buffer,
            cuda_data_ptr truth_buffer,
            const tensor& subnetwork_output,
            tensor& gradient,
            double& loss
        )
        {
            CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));

            // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
            const double scale = 1.0 / (subnetwork_output.num_samples() * subnetwork_output.k() * subnetwork_output.nr() * subnetwork_output.nc());

            launch_kernel(_cuda_compute_loss_mean_squared_per_channel_and_pixel , max_jobs(gradient.size()),
                loss_work_buffer.data(), gradient.device(), truth_buffer.data(), subnetwork_output.device(), gradient.size(), scale);

            float floss;
            dlib::cuda::memcpy(&floss, loss_work_buffer);
            loss = scale*floss;
        }

    // ----------------------------------------------------------------------------------------

    }
}

================================================
FILE: dlib/cuda/cuda_dlib.h
================================================
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_H_
#define DLIB_DNN_CuDA_H_


#include "tensor.h"
#include "../geometry/rectangle.h"
#include "../dnn/utilities.h"

namespace dlib
{
    namespace cuda
    {

    // ----------------------------------------------------------------------------------------

        // Current-device management. These are declared outside the DLIB_USE_CUDA guard
        // below. raii_set_device uses get_device() to remember the current device and
        // set_device() to switch to, and later restore, a device id.
        void set_device (
            int dev
        );

        int get_device (
        );

        int get_num_devices (
        );

        std::string get_device_name (
            int device
        );

        void set_current_device_blocking_sync(
        );

        // The tensor overloads presumably use each tensor's device_id() (TODO confirm).
        bool can_access_peer (int device_id, int peer_device_id);
        bool can_access_peer (const tensor& device, const tensor& peer_device);

        void device_synchronize (int dev);
        void device_synchronize (const tensor& dev);

        // Scoped device switch. The constructor records get_device() and then makes the given
        // device (an id, or a tensor's device_id()) current. operator() switches to another
        // device without changing the recorded one. The destructor restores the recorded
        // device.
        //
        // Non-copyable and not default-constructible. The destructor is noexcept(false), so
        // an exception thrown by set_device() while restoring propagates rather than
        // triggering std::terminate (outside of stack unwinding). Presumably set_device()
        // throws on CUDA errors — TODO confirm.
        class raii_set_device
        {
        public:
            raii_set_device() = delete;
            raii_set_device(const raii_set_device&) = delete;
            raii_set_device& operator=(const raii_set_device&) = delete;

            raii_set_device(int dev)
            {
                prev_dev = get_device();
                set_device(dev);
            }

            raii_set_device(const tensor& dev)
            {
                prev_dev = get_device();
                set_device(dev.device_id());
            }

            void operator() (int dev)
            {
                set_device(dev);
            }

            void operator() (const tensor& dev)
            {
                set_device(dev.device_id());
            }

            ~raii_set_device() noexcept(false)
            {
                set_device(prev_dev);
            }

        private:
            // Device that was current when this object was constructed.
            int prev_dev;
        };


// Everything from here to the matching #endif (later in the file) is only declared
// when dlib is built with CUDA support.
#ifdef DLIB_USE_CUDA

        // Non-copyable scoped helper. The tensor overload forwards to the int overload
        // using each tensor's device_id(). Only declarations are visible here. From the
        // class and member names, it presumably enables peer access from device_id to
        // peer_device_id for its lifetime. It presumably disables that access again on
        // destruction when call_disable is set. TODO confirm against the definition.
        class enable_peer_access
        {
        public:

            enable_peer_access() = delete;
            enable_peer_access(const enable_peer_access&) = delete;
            enable_peer_access& operator=(const enable_peer_access&) = delete;

            enable_peer_access(
                int device_id,
                int peer_device_id
            );

            enable_peer_access(
                const tensor& device,
                const tensor& peer_device
            ) : enable_peer_access(device.device_id(), peer_device.device_id())
            {}

            ~enable_peer_access() noexcept(false);

        private:

            bool call_disable;
            int device_id;
            int peer_device_id;
        };

    // -----------------------------------------------------------------------------------

        // CUDA implementations of tensor helpers. Their contracts are not spelled out in
        // this header; presumably they back the tt:: functions of the same names.
        // TODO confirm in tensor_tools.h.
        void inverse_norms (
            resizable_tensor& invnorms,
            const tensor& data,
            const double eps
        );

        void dot_prods (
            resizable_tensor& out,
            const tensor& lhs,
            const tensor& rhs
        );

        void dot_prods (
            bool add_to,
            tensor& out,
            const tensor& lhs,
            const tensor& rhs
        );

        void scale_columns (
            tensor& out,
            const tensor& m,
            const tensor& v
        );

        void scale_rows (
            tensor& out,
            const tensor& m,
            const tensor& v
        );

        void scale_rows2 (
            float beta,
            tensor& out,
            const tensor& m1,
            const tensor& m2,
            const tensor& v1,
            const tensor& v2
        );

        // Presumably element-wise exp / natural log / base-10 log of src written into
        // dest. TODO confirm.
        void exp (
            tensor& dest,
            const tensor& src
        );

        void log (
            tensor& dest,
            const tensor& src
        );

        void log10 (
            tensor& dest,
            const tensor& src
        );

    // ------------------------------------------------------------------------------------

        // Presumably: fill every element of t with value / multiply every element by value.
        void set_tensor (
            tensor& t,
            float value
        );

        void scale_tensor (
            tensor& t,
            float value
        );

    // ------------------------------------------------------------------------------------

        // Element-combining operations. The add_to flag presumably selects accumulating
        // into dest instead of overwriting it. TODO confirm.
        void multiply (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );

        void multiply_conv (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );

        void multiply_zero_padded (
            bool add_to,
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );

        void scale_channels (
            bool add_to,
            tensor& dest,
            const tensor& src,
            const tensor& scales
        );

        void add (
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );

    // -----------------------------------------------------------------------------------

        // affine_transform overloads taking one to three source tensors and scalar
        // coefficients. They presumably compute dest = A*src1 + B*src2 + C*src3 + D with the
        // unused terms dropped. affine_transform_range presumably restricts this to the flat
        // index range [begin, end). The rectangle overload presumably restricts it to rect.
        // TODO confirm against the tt:: documentation.
        void affine_transform(
            tensor& dest,
            const tensor& src,
            const float A,
            const float B
        );

        void affine_transform(
            tensor& dest,
            const tensor& src,
            const float A
        );

        void affine_transform(
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const float A,
            const float B,
            const float C
        );

        void affine_transform(
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const float A,
            const float B
        );

        void affine_transform(
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const tensor& src3,
            const float A,
            const float B,
            const float C,
            const float D
        );

        void affine_transform_range(
            size_t begin,
            size_t end,
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const tensor& src3,
            const float A,
            const float B,
            const float C
        );

        void affine_transform(
            const rectangle& rect,
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const tensor& src3,
            float A,
            float B,
            float C
        );

        // Note that add_scaled() has no tt:: counterpart: it is called by cuda::add(),
        // so a tt:: version of add_scaled() isn't needed.
// Presumably dest += scale*src (TODO confirm against the definition). Only cuda::add()
// needs it, so it has no tt:: wrapper.
void add_scaled(
    tensor& dest,
    const float scale,
    const tensor& src
);

// Presumably dest = beta*dest plus alpha*src added to every column (BLAS-style
// beta/alpha naming). TODO confirm against the definition.
void add_cv_to_all_columns(
    float beta,
    tensor& dest,
    float alpha,
    const tensor& src
);

// -----------------------------------------------------------------------------------

// Overload taking tensor-valued coefficients A and B instead of float scalars.
void affine_transform(
    tensor& dest,
    const tensor& src,
    const tensor& A,
    const tensor& B
);

// -----------------------------------------------------------------------------------

// Tensor-coefficient affine transform for convolutional layouts. A and B presumably
// hold one coefficient per channel. TODO confirm.
void affine_transform_conv(
    tensor& dest,
    const tensor& src,
    const tensor& A,
    const tensor& B
);

// ----------------------------------------------------------------------------------------

// One Adam optimizer step. The following meanings are inferred from parameter names
// only (TODO confirm against the definition):
//  - [begin, end) selects the range of parameter indices to update;
//  - s receives the step;
//  - m and v are the running first/second moment estimates;
//  - t is the iteration count.
void compute_adam_update (
    size_t begin,
    size_t end,
    tensor& s,
    tensor& m,
    tensor& v,
    const float t,
    const float learning_rate,
    const float weight_decay,
    const float momentum1,
    const float momentum2,
    const tensor& params,
    const tensor& params_grad
);

// -----------------------------------------------------------------------------------

// Writes the bias gradient derived from gradient_input into grad. This is presumably
// a sum over samples/positions. TODO confirm.
void assign_bias_gradient (
    tensor& grad,
    const tensor& gradient_input
);

// -----------------------------------------------------------------------------------

// Layer normalization. Per the CUDA implementation in cuda_dlib.cu:
//  - dest is resized to src's size;
//  - means and invstds are resized to one value per sample;
//  - statistics are taken per sample over all k*nr*nc values.
// gamma must have k() == src.k() and nr() == nc() == 1; beta is presumably the same
// shape. eps must be > 0.
void layer_normalize (
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& means,
    resizable_tensor& invstds,
    const tensor& src,
    const tensor& gamma,
    const tensor& beta
);

// Backward pass of layer_normalize. Per the CUDA implementation:
//  - gamma_grad and beta_grad are zeroed and then accumulated;
//  - dmeans and dvars are scratch buffers, resized to match means/invstds and zeroed;
//  - src_grad is added to (+=), not overwritten.
// gradient_input and src_grad must have src's dimensions, and eps must be > 0.
void layer_normalize_gradient (
    const double eps,
    const tensor& gradient_input,
    const tensor& means,
    const tensor& invstds,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    tensor& beta_grad,
    resizable_tensor& dmeans,
    resizable_tensor& dvars
);

// -----------------------------------------------------------------------------------

// RMS normalization. Per the CUDA implementation:
//  - scale is resized to one value per sample, set to
//    1/sqrt(mean(src^2 over the sample's k*nr*nc values) + eps);
//  - dest = src * scale[sample] * gamma[channel].
// gamma must have k() == src.k() and nr() == nc() == 1, and eps must be > 0.
void rms_normalize(
    const double eps,
    resizable_tensor& dest,
    resizable_tensor& scale,
    const tensor& src,
    const tensor& gamma
);

// Backward pass of rms_normalize. Per the CUDA implementation:
//  - gamma_grad is zeroed and then accumulated;
//  - dscale is a scratch buffer, resized to scale's size and zeroed;
//  - src_grad is added to (+=), not overwritten.
void rms_normalize_gradient(
    const tensor& gradient_input,
    const tensor& scale,
    const tensor& src,
    const tensor& gamma,
    tensor& src_grad,
    tensor& gamma_grad,
    resizable_tensor& dscale
);

// 
----------------------------------------------------------------------------------- void threshold ( tensor& data, float thresh ); // ---------------------------------------------------------------------------------------- void dot ( const tensor& a, const tensor& b, tensor& result, size_t idx ); // ---------------------------------------------------------------------------------------- void prelu ( tensor& dest, const tensor& src, const tensor& param ); void prelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input, const tensor& param, tensor& params_grad ); // ---------------------------------------------------------------------------------------- void leaky_relu ( tensor& dest, const tensor& src, const float alpha ); void leaky_relu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input, const float alpha ); // ---------------------------------------------------------------------------------------- void mish ( tensor& dest, const tensor& src ); void mish_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ); // ---------------------------------------------------------------------------------------- void clipped_relu ( tensor& dest, const tensor& src, const float coef ); void clipped_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float ceiling ); // ------------------------------------------------------------------------------------ void elu ( tensor& dest, const tensor& src, const float alpha ); void elu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ); // ---------------------------------------------------------------------------------------- void gelu ( tensor& dest, const tensor& src ); void gelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ); // ---------------------------------------------------------------------------------------- void smelu ( tensor& dest, const tensor& src, const 
float beta ); void smelu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float beta ); // ---------------------------------------------------------------------------------------- void silu ( tensor& dest, const tensor& src ); void silu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void resize_bilinear ( tensor& dest, long long dest_row_stride, long long dest_channel_stride, const tensor& src, long long src_row_stride, long long src_channel_stride ); void resize_bilinear_gradient ( tensor& grad, long long grad_row_stride, long long grad_channel_stride, const tensor& gradient_input, long long gradient_input_row_stride, long long gradient_input_channel_stride ); inline void resize_bilinear ( tensor& dest, const tensor& src ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); } inline void resize_bilinear_gradient ( tensor& grad, const tensor& gradient_input ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); } // ---------------------------------------------------------------------------------------- void reorg ( bool add_to, tensor& dest, const int row_stride, const int col_stride, const tensor& src ); void reorg_gradient ( bool add_to, tensor& grad, const int row_stride, const int col_stride, const tensor& gradient_input ); // ----------------------------------------------------------------------------------- void embeddings( resizable_tensor& dest, const tensor& src, const tensor& embs ); void embeddings_gradient( const tensor& prev, const tensor& gradient_input, tensor& grads, const tensor& freqs, float learning_rate, bool scale ); // ---------------------------------------------------------------------------------------- void copy_tensor( bool add_to, tensor& dest, size_t 
dest_k_offset, const tensor& src, size_t src_k_offset, size_t count_k ); // ---------------------------------------------------------------------------------------- void copy_tensor( bool add_to, tensor& dest, size_t dk, size_t dnr, size_t dnc, const tensor& src, size_t sk, size_t snr, size_t snc, size_t k, size_t nr, size_t nc ); // ---------------------------------------------------------------------------------------- void transpose( bool add_to, tensor& dest, const tensor& src ); // ---------------------------------------------------------------------------------------- void compute_act_halt_probabilities( resizable_tensor& halt_probs, resizable_tensor& logits, const tensor& input_data, const tensor& halt_params, long batch_size, long seq_len, long feature_dim ); void update_act_state( resizable_tensor& output, const tensor& input_data, const tensor& halt_probs, resizable_tensor& cumulative_halting, resizable_tensor& remainders, resizable_tensor& n_steps, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels, float halt_threshold, long current_step ); void finalize_act_output( resizable_tensor& output, const tensor& input_data, const tensor& remainders, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels ); void apply_act_depth_scaling( tensor& gradients, const tensor& n_steps, long batch_size, long seq_len, long d_model, long num_channels, float max_steps, float scale_factor ); // ---------------------------------------------------------------------------------------- class compute_loss_binary_log_per_pixel { /*! The point of this class is to compute the loss computed by loss_binary_log_per_pixel_, but to do so with CUDA. 
!*/ public: compute_loss_binary_log_per_pixel( ) { } template < typename const_label_iterator > void operator() ( const_label_iterator truth, const tensor& subnetwork_output, tensor& gradient, double& loss ) const { const auto image_size = subnetwork_output.nr()*subnetwork_output.nc(); const size_t bytes_per_plane = image_size*sizeof(float); // Allocate a cuda buffer to store all the truth images and also one float // for the scalar loss output. buf = device_global_buffer(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float)); cuda_data_ptr loss_buf = static_pointer_cast(buf, 1); buf = buf+sizeof(float); // copy the truth data into a cuda buffer. for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth) { const matrix& t = *truth; DLIB_ASSERT(t.nr() == subnetwork_output.nr()); DLIB_ASSERT(t.nc() == subnetwork_output.nc()); memcpy(buf + i*bytes_per_plane, &t(0,0), bytes_per_plane); } auto truth_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); do_work(loss_buf, truth_buf, subnetwork_output, gradient, loss); } private: static void do_work( cuda_data_ptr loss_work_buffer, cuda_data_ptr truth_buffer, const tensor& subnetwork_output, tensor& gradient, double& loss ); mutable cuda_data_void_ptr buf; }; // ---------------------------------------------------------------------------------------- class compute_loss_multiclass_log_per_pixel { /*! The point of this class is to compute the loss computed by loss_multiclass_log_per_pixel_, but to do so with CUDA. !*/ public: compute_loss_multiclass_log_per_pixel( ) { } template < typename const_label_iterator > void operator() ( const_label_iterator truth, const tensor& subnetwork_output, tensor& gradient, double& loss ) const { const auto image_size = subnetwork_output.nr()*subnetwork_output.nc(); const size_t bytes_per_plane = image_size*sizeof(uint16_t); // Allocate a cuda buffer to store all the truth images and also one float // for the scalar loss output. 
buf = device_global_buffer(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float)); cuda_data_ptr loss_buf = static_pointer_cast(buf, 1); buf = buf+sizeof(float); // copy the truth data into a cuda buffer. for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth) { const matrix& t = *truth; DLIB_ASSERT(t.nr() == subnetwork_output.nr()); DLIB_ASSERT(t.nc() == subnetwork_output.nc()); memcpy(buf + i*bytes_per_plane, &t(0,0), bytes_per_plane); } auto truth_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); do_work(loss_buf, truth_buf, subnetwork_output, gradient, loss); } private: static void do_work( cuda_data_ptr loss_work_buffer, cuda_data_ptr truth_buffer, const tensor& subnetwork_output, tensor& gradient, double& loss ); mutable cuda_data_void_ptr buf; }; // ---------------------------------------------------------------------------------------- class compute_loss_multiclass_log_per_pixel_weighted { /*! The point of this class is to compute the loss computed by loss_multiclass_log_per_pixel_weighted_, but to do so with CUDA. !*/ public: compute_loss_multiclass_log_per_pixel_weighted( ) { } template < typename const_label_iterator > void operator() ( const_label_iterator truth, const tensor& subnetwork_output, tensor& gradient, double& loss ) const { const auto image_size = subnetwork_output.nr()*subnetwork_output.nc(); const size_t bytes_per_plane = image_size*sizeof(uint16_t); const size_t weight_bytes_per_plane = image_size*sizeof(float); matrix labels(truth->nr(), truth->nc()); matrix weights(truth->nr(), truth->nc()); // Allocate a cuda buffer to store all the truth images and also one float // for the scalar loss output. 
buf = device_global_buffer(subnetwork_output.num_samples()*(bytes_per_plane + weight_bytes_per_plane) + sizeof(float)); cuda_data_ptr loss_buf = static_pointer_cast(buf, 1); buf = buf+sizeof(float); const auto truth_offset = subnetwork_output.num_samples() * weight_bytes_per_plane; // copy the truth data into a cuda buffer. for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth) { const matrix>& t = *truth; DLIB_ASSERT(t.nr() == subnetwork_output.nr()); DLIB_ASSERT(t.nc() == subnetwork_output.nc()); for (long r = 0; r < t.nr(); ++r) { for (long c = 0; c < t.nc(); ++c) { labels(r, c) = t(r, c).label; weights(r, c) = t(r, c).weight; } } memcpy(buf + truth_offset + i*bytes_per_plane, &labels(0,0), bytes_per_plane); memcpy(buf + i*weight_bytes_per_plane, &weights(0, 0), weight_bytes_per_plane); } auto weights_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); buf = buf+truth_offset; auto truth_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); do_work(loss_buf, truth_buf, weights_buf, subnetwork_output, gradient, loss); } private: static void do_work( cuda_data_ptr loss_work_buffer, cuda_data_ptr truth_buffer, cuda_data_ptr weights_buffer, const tensor& subnetwork_output, tensor& gradient, double& loss ); mutable cuda_data_void_ptr buf; }; // ---------------------------------------------------------------------------------------- class compute_loss_mean_squared_per_channel_and_pixel { /*! The point of this class is to compute the loss computed by loss_mean_squared_per_channel_and_pixel_, but to do so with CUDA. 
!*/ public: compute_loss_mean_squared_per_channel_and_pixel( ) { } template < typename const_label_iterator > void operator() ( const_label_iterator truth, const tensor& subnetwork_output, tensor& gradient, double& loss ) const { const auto image_size = subnetwork_output.nr()*subnetwork_output.nc()*subnetwork_output.k(); const size_t bytes_per_image = image_size*sizeof(float); // Allocate a cuda buffer to store all the truth images and also one float // for the scalar loss output. buf = device_global_buffer(subnetwork_output.num_samples()*bytes_per_image + sizeof(float)); cuda_data_ptr loss_buf = static_pointer_cast(buf, 1); buf = buf+sizeof(float); const size_t bytes_per_plane = subnetwork_output.nr()*subnetwork_output.nc()*sizeof(float); // copy the truth data into a cuda buffer. for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth) { const auto& t = *truth; DLIB_ASSERT(static_cast(t.size()) == subnetwork_output.k()); for (size_t j = 0; j < t.size(); ++j) { DLIB_ASSERT(t[j].nr() == subnetwork_output.nr()); DLIB_ASSERT(t[j].nc() == subnetwork_output.nc()); memcpy(buf + i*bytes_per_image + j*bytes_per_plane, &t[j](0,0), bytes_per_plane); } } auto truth_buf = static_pointer_cast(buf, subnetwork_output.num_samples()*image_size); do_work(loss_buf, truth_buf, subnetwork_output, gradient, loss); } private: static void do_work( cuda_data_ptr loss_work_buffer, cuda_data_ptr truth_buffer, const tensor& subnetwork_output, tensor& gradient, double& loss ); mutable cuda_data_void_ptr buf; }; // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ #else // if DLIB_USE_CUDA NOT DEFINED inline void set_device ( int id ) { DLIB_CASSERT(id == 0, "dlib::cuda::set_device(id) 
called with an invalid device id."); } inline int get_device ( ){ return 0; } inline int get_num_devices ( ) { return 1; } inline std::string get_device_name ( int device ) { DLIB_CASSERT(device == 0, "dlib::cuda::set_device(id) called with an invalid device id."); return "CUDA_DISABLED"; } inline void set_current_device_blocking_sync( ) {} inline bool can_access_peer (int , int ) { return false; } inline bool can_access_peer (const tensor& , const tensor& ) { return false; } inline void device_synchronize (int ){} inline void device_synchronize (const tensor& ){} class enable_peer_access { public: enable_peer_access() = delete; enable_peer_access(const enable_peer_access&) = delete; enable_peer_access& operator=(const enable_peer_access&) = delete; enable_peer_access( int, int ){} enable_peer_access( const tensor&, const tensor& ) {} }; #endif // DLIB_USE_CUDA } } #endif // DLIB_DNN_CuDA_H_ ================================================ FILE: dlib/cuda/cuda_errors.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CUDA_ERRORs_H_ #define DLIB_CUDA_ERRORs_H_ #include "../error.h" namespace dlib { struct cuda_error : public error { /*! WHAT THIS OBJECT REPRESENTS This is the exception thrown if any calls to the NVIDIA CUDA runtime returns an error. !*/ cuda_error(const std::string& message): error(message) {} }; struct cudnn_error : public cuda_error { /*! WHAT THIS OBJECT REPRESENTS This is the exception thrown if any calls to the NVIDIA cuDNN library returns an error. !*/ cudnn_error(const std::string& message): cuda_error(message) {} }; struct curand_error : public cuda_error { /*! WHAT THIS OBJECT REPRESENTS This is the exception thrown if any calls to the NVIDIA cuRAND library returns an error. !*/ curand_error(const std::string& message): cuda_error(message) {} }; struct cublas_error : public cuda_error { /*! 
WHAT THIS OBJECT REPRESENTS This is the exception thrown if any calls to the NVIDIA cuBLAS library returns an error. !*/ cublas_error(const std::string& message): cuda_error(message) {} }; struct cusolver_error : public cuda_error { /*! WHAT THIS OBJECT REPRESENTS This is the exception thrown if any calls to the NVIDIA cuSolver library returns an error. !*/ cusolver_error(const std::string& message): cuda_error(message) {} }; } #endif // DLIB_CUDA_ERRORs_H_ ================================================ FILE: dlib/cuda/cuda_utils.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CUDA_UtILS_H_ #define DLIB_CUDA_UtILS_H_ #include "../algs.h" #ifndef DLIB_USE_CUDA #error "This file shouldn't be #included unless DLIB_USE_CUDA is #defined" #endif #include "cuda_errors.h" #include #include #include #include #include #include #include // Check the return value of a call to the CUDA runtime for an error condition. #define CHECK_CUDA(call) \ do{ \ const cudaError_t error = call; \ if (error != cudaSuccess) \ { \ std::ostringstream sout; \ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ sout << "code: " << cudaGetLastError() << ", reason: " << cudaGetErrorString(error);\ throw dlib::cuda_error(sout.str()); \ } \ }while(false) // ---------------------------------------------------------------------------------------- #ifdef __CUDACC__ namespace dlib { namespace cuda { // ------------------------------------------------------------------------------------ __inline__ __device__ size_t pack_idx ( size_t dim_size3, size_t dim_size2, size_t dim_size1, size_t idx4, size_t idx3, size_t idx2, size_t idx1 ) /*! ensures - Converts a 4D array index into a 1D index assuming row major layout. 
To understand precisely what this function does, imagine we had an array declared like this: int ARRAY[anything][dim_size3][dim_size2][dim_size1]; Then we could index it like this: ARRAY[idx4][idx3][idx2][idx1] or equivalently like this: ((int*)ARRAY)[pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1)] !*/ { return ((idx4*dim_size3 + idx3)*dim_size2 + idx2)*dim_size1 + idx1; } __inline__ __device__ void unpack_idx ( size_t idx, size_t dim_size3, size_t dim_size2, size_t dim_size1, size_t& idx4, size_t& idx3, size_t& idx2, size_t& idx1 ) /*! ensures - This function computes the inverse of pack_idx(). Therefore, if PACKED == pack_idx(dim_size3,dim_size2,dim_size1, idx4,idx3,idx2,idx1) then unpack_idx(PACKED,dim_size3,dim_size2,dim_size1, IDX4,IDX3,IDX2,IDX1) results in: - IDX1 == idx1 - IDX2 == idx2 - IDX3 == idx3 - IDX4 == idx4 !*/ { idx1 = idx%dim_size1; idx /= dim_size1; idx2 = idx%dim_size2; idx /= dim_size2; idx3 = idx%dim_size3; idx /= dim_size3; idx4 = idx; } // ------------------------------------------------------------------------------------ // This function is from the article: // http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ __inline__ __device__ float warp_reduce_sum(float val) { for (int offset = warpSize/2; offset > 0; offset /= 2) #if CUDART_VERSION >= 9000 val += __shfl_down_sync(0xFFFFFFFF,val, offset); #else val += __shfl_down(val, offset); #endif return val; } __inline__ __device__ bool is_first_thread_in_warp() { return (threadIdx.x & (warpSize - 1)) == 0; } __inline__ __device__ void warp_reduce_atomic_add( float& out, float val ) /*! ensures - Atomically adds all the val variables in the current warp to out. 
See this page for an extended discussion: http://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/ !*/ { val = warp_reduce_sum(val); if (is_first_thread_in_warp()) atomicAdd(&out, val); } // ------------------------------------------------------------------------------------ struct max_jobs { max_jobs(int x) : num_x(x) {} max_jobs(int x, int y) : num_x(x), num_y(y) {} int num_x; int num_y = 1; }; template void launch_kernel ( Kernel K, T ...args ) /*! ensures - launches the given kernel K(args...). The point of this function is to automatically set the kernel launch parameters to something reasonable based on the properties of the kernel and the current GPU card. !*/ { int num_blocks, num_threads; CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K)); K<<>>(args...); } template void launch_kernel ( Kernel K, max_jobs m, T ...args ) /*! ensures - This function is just like launch_kernel(K,args...) except that you can additionally supply a max_jobs number that tells it how many possible total threads could be used. This is useful when launching potentially small jobs that might not need the number of threads suggested by launch_kernel(). !*/ { if (m.num_x == 0 || m.num_y == 0) return; int num_blocks, num_threads; CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K)); // Check if the job is really small and we don't really need to launch a kernel // with this many blocks and threads. if (num_blocks*num_threads > m.num_x*m.num_y) num_blocks = (m.num_x*m.num_y+num_threads-1)/num_threads; if (m.num_y == 1) { K<<>>(args...); } else { /* In general, the reason m.num_y!=1 (i.e. the reason you are in this code path) is because we are using nested grid-stride loops. There are two important things to note about what we are doing here. To illustrate them we will talk about this little CUDA code snippet: // initialize out before we begin. 
for (auto i : grid_stride_range_y(0, nr)) for (auto j : grid_stride_range(0, 1)) out[i] = 0; __syncthreads(); // synchronize threads in block // loop over some 2D thing and sum and store things into out. for (auto i : grid_stride_range_y(0, nr)) { float temp = 0; for (auto j : grid_stride_range(0, nc)) temp += whatever[i*nc+j]; // store the sum into out[i] warp_reduce_atomic_add(out[i], temp); } First, we make sure the number of x threads is a multiple of 32 so that you can use warp_reduce_atomic_add() inside the y loop. Second, we put the x block size to 1 so inter-block synchronization is easier. For example, if the number of x blocks wasn't 1 the above code would have a race condition in it. This is because the execution of out[i]=0 would be done by blocks with blockIdx.x==0, but then in the second set of loops, *all* the x blocks use out[i]. Since __syncthreads() doesn't do any synchronization between blocks some of the blocks might begin before the out[i]=0 statements finished and that would be super bad. */ // Try and make sure that the ratio of x to y threads is reasonable based // on the respective size of our loops. int x_threads = 32; int y_threads = num_threads/32; const int ratio = static_cast(std::round(put_in_range(1, y_threads, m.num_x/(double)m.num_y))); x_threads *= ratio; y_threads /= ratio; dim3 blocks(1,num_blocks); dim3 threads(x_threads,y_threads); K<<>>(args...); } } // ------------------------------------------------------------------------------------ class grid_stride_range { /*! WHAT THIS OBJECT REPRESENTS This is a tool for making a for loop that loops over an entire block of memory inside a kernel, but doing so in a way that parallelizes appropriately across all the threads in a kernel launch. 
For example, the following kernel would add the vector a to the vector b and store the output in out (assuming all vectors are of dimension n): __global__ void add_arrays( const float* a, const float* b, float* out, size_t n ) { for (auto i : grid_stride_range(0, n)) { out[i] = a[i]+b[i]; } } !*/ public: __device__ grid_stride_range( size_t ibegin_, size_t iend_ ) : ibegin(ibegin_), iend(iend_) {} class iterator { public: __device__ iterator() {} __device__ iterator(size_t pos_) : pos(pos_) {} __device__ size_t operator*() const { return pos; } __device__ iterator& operator++() { pos += gridDim.x * blockDim.x; return *this; } __device__ bool operator!=(const iterator& item) const { return pos < item.pos; } private: size_t pos; }; __device__ iterator begin() const { return iterator(ibegin+blockDim.x * blockIdx.x + threadIdx.x); } __device__ iterator end() const { return iterator(iend); } private: size_t ibegin; size_t iend; }; // ------------------------------------------------------------------------------------ class grid_stride_range_y { /*! WHAT THIS OBJECT REPRESENTS This object is just like grid_stride_range except that it looks at CUDA's y thread index (e.g. threadIdx.y) instead of the x index. Therefore, if you launch a cuda kernel with a statement like: dim3 blocks(1,10); dim3 threads(32,32); // You need to have x and y not equal to 1 to get parallelism over both loops. add_arrays<<>>(a,b,out,nr,nc); You can perform a nested 2D parallel for loop rather than doing just a 1D for loop. 
So the code in the kernel would look like this if you wanted to add two 2D matrices: __global__ void add_arrays( const float* a, const float* b, float* out, size_t nr, size_t nc ) { for (auto r : grid_stride_range_y(0, nr)) { for (auto c : grid_stride_range(0, nc)) { auto i = r*nc+c; out[i] = a[i]+b[i]; } } } !*/ public: __device__ grid_stride_range_y( size_t ibegin_, size_t iend_ ) : ibegin(ibegin_), iend(iend_) {} class iterator { public: __device__ iterator() {} __device__ iterator(size_t pos_) : pos(pos_) {} __device__ size_t operator*() const { return pos; } __device__ iterator& operator++() { pos += gridDim.y * blockDim.y; return *this; } __device__ bool operator!=(const iterator& item) const { return pos < item.pos; } private: size_t pos; }; __device__ iterator begin() const { return iterator(ibegin+blockDim.y * blockIdx.y + threadIdx.y); } __device__ iterator end() const { return iterator(iend); } private: size_t ibegin; size_t iend; }; // ------------------------------------------------------------------------------------ } } #endif // __CUDACC__ // ---------------------------------------------------------------------------------------- #endif // DLIB_CUDA_UtILS_H_ ================================================ FILE: dlib/cuda/cudnn_dlibapi.cpp ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DNN_CuDNN_CPP_ #define DLIB_DNN_CuDNN_CPP_ #ifdef DLIB_USE_CUDA #include "cudnn_dlibapi.h" #include "tensor.h" #include #include #include #include #include #include #include "cuda_utils.h" #include "cpu_dlib.h" #include "cuda_dlib.h" #include "tensor_tools.h" static const char* cudnn_get_error_string(cudnnStatus_t s) { switch(s) { case CUDNN_STATUS_NOT_INITIALIZED: return "CUDA Runtime API initialization failed."; case CUDNN_STATUS_ALLOC_FAILED: return "CUDA Resources could not be allocated."; case CUDNN_STATUS_BAD_PARAM: return "CUDNN_STATUS_BAD_PARAM"; case CUDNN_STATUS_EXECUTION_FAILED: return "CUDNN_STATUS_EXECUTION_FAILED"; case CUDNN_STATUS_NOT_SUPPORTED: return "CUDNN_STATUS_NOT_SUPPORTED"; case CUDNN_STATUS_ARCH_MISMATCH: return "CUDNN_STATUS_ARCH_MISMATCH: Your GPU is too old and not supported by cuDNN"; default: return "A call to cuDNN failed"; } } // Check the return value of a call to the cuDNN runtime for an error condition. #define CHECK_CUDNN(call) \ do{ \ const cudnnStatus_t error = call; \ if (error != CUDNN_STATUS_SUCCESS) \ { \ std::ostringstream sout; \ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". 
";\ sout << "code: " << error << ", reason: " << cudnn_get_error_string(error);\ throw dlib::cudnn_error(sout.str()); \ } \ }while(false) namespace dlib { namespace cuda { // ------------------------------------------------------------------------------------ static cudnnTensorDescriptor_t descriptor(const tensor& t) { return (const cudnnTensorDescriptor_t)t.get_cudnn_tensor_descriptor().get_handle(); } static cudnnTensorDescriptor_t descriptor(const tensor_descriptor& t) { return (const cudnnTensorDescriptor_t)t.get_handle(); } // ------------------------------------------------------------------------------------ class cudnn_context { public: // not copyable cudnn_context(const cudnn_context&) = delete; cudnn_context& operator=(const cudnn_context&) = delete; cudnn_context() { handles.resize(16); } ~cudnn_context() { for (auto h : handles) { if (h) cudnnDestroy(h); } } cudnnHandle_t get_handle ( ) { int new_device_id; CHECK_CUDA(cudaGetDevice(&new_device_id)); // make room for more devices if needed if (new_device_id >= (long)handles.size()) handles.resize(new_device_id+16); // If we don't have a handle already for this device then make one if (!handles[new_device_id]) CHECK_CUDNN(cudnnCreate(&handles[new_device_id])); // Finally, return the handle for the current device return handles[new_device_id]; } private: std::vector handles; }; static cudnnHandle_t context() { thread_local cudnn_context c; return c.get_handle(); } // ------------------------------------------------------------------------------------ class cudnn_activation_descriptor { public: // not copyable cudnn_activation_descriptor(const cudnn_activation_descriptor&) = delete; cudnn_activation_descriptor& operator=(const cudnn_activation_descriptor&) = delete; cudnn_activation_descriptor( cudnnActivationMode_t mode, cudnnNanPropagation_t reluNanOpt, double coef ) { CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle)); CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, coef)); } 
~cudnn_activation_descriptor() { cudnnDestroyActivationDescriptor(handle); } cudnnActivationDescriptor_t get_handle ( ) { return handle; } private: cudnnActivationDescriptor_t handle; }; static cudnnActivationDescriptor_t identity_activation_descriptor() { thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_IDENTITY, CUDNN_PROPAGATE_NAN,0); return des.get_handle(); } static cudnnActivationDescriptor_t relu_activation_descriptor() { thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,0); return des.get_handle(); } static cudnnActivationDescriptor_t sigmoid_activation_descriptor() { thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN,0); return des.get_handle(); } static cudnnActivationDescriptor_t tanh_activation_descriptor() { thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN,0); return des.get_handle(); } // ------------------------------------------------------------------------------------ tensor_descriptor:: tensor_descriptor( ) : handle(nullptr) { } tensor_descriptor:: ~tensor_descriptor() { set_size(0,0,0,0); } void tensor_descriptor:: set_size( int n, int k, int nr, int nc ) { if (handle) { cudnnDestroyTensorDescriptor((cudnnTensorDescriptor_t)handle); handle = nullptr; } if (n != 0 && nr != 0 && nc != 0 && k != 0) { cudnnTensorDescriptor_t h; CHECK_CUDNN(cudnnCreateTensorDescriptor(&h)); handle = h; CHECK_CUDNN(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, k, nr, nc)); } } void tensor_descriptor:: get_size ( int& n, int& k, int& nr, int& nc ) const { if (handle) { int nStride, cStride, hStride, wStride; cudnnDataType_t datatype; CHECK_CUDNN(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle, &datatype, &n, &k, &nr, &nc, &nStride, &cStride, &hStride, &wStride)); } else { n = 0; k = 0; nr = 0; nc = 0; } } // 
------------------------------------------------------------------------------------ void add( float beta, tensor& dest, float alpha, const tensor& src ) { DLIB_CASSERT( (have_same_dimensions(src, dest) || (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) || (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) || (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()) || (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1)) && is_same_object(src,dest) == false , "\n\t dest.num_samples(): " << dest.num_samples() <<"\n\t dest.k(): " << dest.k() <<"\n\t dest.nr(): " << dest.nr() <<"\n\t dest.nc(): " << dest.nc() <<"\n\t src.num_samples(): " << src.num_samples() <<"\n\t src.k(): " << src.k() <<"\n\t src.nr(): " << src.nr() <<"\n\t src.nc(): " << src.nc() ); if (dest.size() == src.size() && beta == 1) { // Call the dlib function in this case since it's faster than the one that // comes with cuDNN (at least as of cuDNN v4). 
add_scaled(dest, alpha, src); return; } else if (src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1) { add_cv_to_all_columns(beta, dest, alpha, src); return; } CHECK_CUDNN(cudnnAddTensor(context(), &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } void assign_conv_bias_gradient ( tensor& grad, const tensor& gradient_input ) { DLIB_CASSERT( grad.num_samples() == 1 && grad.k() >= 1 && grad.nr() == 1 && grad.nc() == 1 && gradient_input.k() == grad.k() && gradient_input.size() > 0 && is_same_object(grad,gradient_input) == false ); const float alpha = 1; const float beta = 0; CHECK_CUDNN(cudnnConvolutionBackwardBias(context(), &alpha, descriptor(gradient_input), gradient_input.device(), &beta, descriptor(grad), grad.device())); } // ------------------------------------------------------------------------------------ void batch_normalize_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ) { DLIB_CASSERT( gamma.num_samples() == 1 && gamma.nr() == src.nr() && gamma.nc() == src.nc() && gamma.k() == src.k() && have_same_dimensions(gamma, beta) && have_same_dimensions(gamma, running_means) && have_same_dimensions(gamma, running_variances) && eps > 0, "\ngamma.num_samples(): " << gamma.num_samples() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.num_samples(): " << beta.num_samples() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\nrunning_means.num_samples(): " << running_means.num_samples() << "\nrunning_means.k(): " << running_means.k() << "\nrunning_means.nr(): " << running_means.nr() << "\nrunning_means.nc(): " << running_means.nc() << "\nrunning_variances.num_samples(): " << running_variances.num_samples() << "\nrunning_variances.k(): " << running_variances.k() << 
"\nrunning_variances.nr(): " << running_variances.nr() << "\nrunning_variances.nc(): " << running_variances.nc() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); const float in_scale = 1; const float out_scale = 0; dest.copy_size(src); CHECK_CUDNN(cudnnBatchNormalizationForwardInference( context(), CUDNN_BATCHNORM_PER_ACTIVATION, &in_scale, &out_scale, descriptor(src), src.device(), descriptor(dest), dest.device(), descriptor(gamma), gamma.device(), beta.device(), running_means.device(), running_variances.device(), eps)); } void batch_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ) { DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_means,means)); DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); DLIB_CASSERT( src.num_samples() > 1 && gamma.num_samples() == 1 && beta.num_samples() == 1 && gamma.nr() == beta.nr() && beta.nr() == src.nr() && gamma.nc() == beta.nc() && beta.nc() == src.nc() && gamma.k() == beta.k() && beta.k() == src.k() && eps > 0, "\ngamma.num_samples(): " << gamma.num_samples() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.num_samples(): " << beta.num_samples() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); const float in_scale = 1; const float out_scale = 0; dest.copy_size(src); means.set_size(1, src.k(), src.nr(), src.nc()); invstds.copy_size(means); running_means.copy_size(means); 
running_variances.copy_size(means); // cuDNN requires that running_means and running_variances be initialized to // some valid float values even if the averaging factor would have ignored // them. if (averaging_factor == 1) { running_means = 0; running_variances = 1; } CHECK_CUDNN(cudnnBatchNormalizationForwardTraining( context(), CUDNN_BATCHNORM_PER_ACTIVATION, &in_scale, &out_scale, descriptor(src), src.device(), descriptor(dest), dest.device(), descriptor(gamma), gamma.device(), beta.device(), averaging_factor, running_means.device(), running_variances.device(), eps, means.device(), invstds.device())); } void batch_normalize_gradient( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ) { const long num = src.k()*src.nr()*src.nc(); DLIB_CASSERT(src.num_samples() > 1); DLIB_CASSERT(num == (long)means.size()); DLIB_CASSERT(num == (long)invstds.size()); DLIB_CASSERT(num == (long)gamma.size()); DLIB_CASSERT(num == (long)gamma_grad.size()); DLIB_CASSERT(num == (long)beta_grad.size()); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); DLIB_CASSERT(eps > 0); const float in_scale = 1; const float out_scale = 1; const float in_scale_params = 1; const float out_scale_params = 0; CHECK_CUDNN(cudnnBatchNormalizationBackward( context(), CUDNN_BATCHNORM_PER_ACTIVATION, &in_scale, &out_scale, &in_scale_params, &out_scale_params, descriptor(src), src.device(), descriptor(gradient_input), gradient_input.device(), descriptor(src_grad), src_grad.device(), descriptor(gamma), gamma.device(), gamma_grad.device(), beta_grad.device(), eps, means.device(), invstds.device())); } // ------------------------------------------------------------------------------------ void batch_normalize_conv_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, 
const tensor& beta, const tensor& running_means, const tensor& running_variances ) { DLIB_CASSERT( gamma.num_samples() == 1 && gamma.nr() == 1 && gamma.nc() == 1 && gamma.k() == src.k() && have_same_dimensions(gamma, beta) && have_same_dimensions(gamma, running_means) && have_same_dimensions(gamma, running_variances) && eps > 0, "\ngamma.num_samples(): " << gamma.num_samples() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.num_samples(): " << beta.num_samples() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\nrunning_means.num_samples(): " << running_means.num_samples() << "\nrunning_means.k(): " << running_means.k() << "\nrunning_means.nr(): " << running_means.nr() << "\nrunning_means.nc(): " << running_means.nc() << "\nrunning_variances.num_samples(): " << running_variances.num_samples() << "\nrunning_variances.k(): " << running_variances.k() << "\nrunning_variances.nr(): " << running_variances.nr() << "\nrunning_variances.nc(): " << running_variances.nc() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); const float in_scale = 1; const float out_scale = 0; dest.copy_size(src); CHECK_CUDNN(cudnnBatchNormalizationForwardInference( context(), CUDNN_BATCHNORM_SPATIAL, &in_scale, &out_scale, descriptor(src), src.device(), descriptor(dest), dest.device(), descriptor(gamma), gamma.device(), beta.device(), running_means.device(), running_variances.device(), eps)); } void batch_normalize_conv ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ) { DLIB_CASSERT(0 <= averaging_factor && averaging_factor <= 1, "averaging_factor: " << averaging_factor); DLIB_CASSERT(averaging_factor==1 || 
have_same_dimensions(running_means,means)); DLIB_CASSERT(averaging_factor==1 || have_same_dimensions(running_variances,invstds)); DLIB_CASSERT( src.num_samples() > 1 && gamma.num_samples() == 1 && beta.num_samples() == 1 && gamma.nr() == 1 && beta.nr() == 1 && gamma.nc() == 1 && beta.nc() == 1 && gamma.k() == beta.k() && beta.k() == src.k() && eps > 0, "\ngamma.num_samples(): " << gamma.num_samples() << "\ngamma.k(): " << gamma.k() << "\ngamma.nr(): " << gamma.nr() << "\ngamma.nc(): " << gamma.nc() << "\nbeta.num_samples(): " << beta.num_samples() << "\nbeta.k(): " << beta.k() << "\nbeta.nr(): " << beta.nr() << "\nbeta.nc(): " << beta.nc() << "\nsrc.k(): " << src.k() << "\nsrc.nr(): " << src.nr() << "\nsrc.nc(): " << src.nc() << "\neps: " << eps ); const float in_scale = 1; const float out_scale = 0; dest.copy_size(src); means.set_size(1, src.k()); invstds.copy_size(means); running_means.copy_size(means); running_variances.copy_size(means); // cuDNN requires that running_means and running_variances be initialized to // some valid float values even if the averaging factor would have ignored // them. 
if (averaging_factor == 1) { running_means = 0; running_variances = 1; } CHECK_CUDNN(cudnnBatchNormalizationForwardTraining( context(), CUDNN_BATCHNORM_SPATIAL, &in_scale, &out_scale, descriptor(src), src.device(), descriptor(dest), dest.device(), descriptor(gamma), gamma.device(), beta.device(), averaging_factor, running_means.device(), running_variances.device(), eps, means.device(), invstds.device())); } void batch_normalize_conv_gradient( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ) { DLIB_CASSERT(src.k() == (long)means.size()); DLIB_CASSERT(src.k() == (long)invstds.size()); DLIB_CASSERT(src.k() == (long)gamma.size()); DLIB_CASSERT(src.k() == (long)gamma_grad.size()); DLIB_CASSERT(src.k() == (long)beta_grad.size()); DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); DLIB_CASSERT(eps > 0); const float in_scale = 1; const float out_scale = 1; const float in_scale_params = 1; const float out_scale_params = 0; CHECK_CUDNN(cudnnBatchNormalizationBackward( context(), CUDNN_BATCHNORM_SPATIAL, &in_scale, &out_scale, &in_scale_params, &out_scale_params, descriptor(src), src.device(), descriptor(gradient_input), gradient_input.device(), descriptor(src_grad), src_grad.device(), descriptor(gamma), gamma.device(), gamma_grad.device(), beta_grad.device(), eps, means.device(), invstds.device())); } // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ tensor_conv:: tensor_conv( ) : filter_handle(nullptr), conv_handle(nullptr), forward_algo(0), backward_data_algo(0), backward_filters_algo(0) { clear(); } void tensor_conv:: clear ( ) { if (filter_handle) cudnnDestroyFilterDescriptor((cudnnFilterDescriptor_t)filter_handle); if (conv_handle) 
cudnnDestroyConvolutionDescriptor((cudnnConvolutionDescriptor_t)conv_handle); filter_handle = nullptr; conv_handle = nullptr; out_num_samples = 0; out_k = 0; out_nr = 0; out_nc = 0; stride_y = 0; stride_x = 0; padding_y = 0; padding_x = 0; data_num_samples = 0; data_k = 0; data_nr = 0; data_nc = 0; filters_num_samples = 0; filters_k = 0; filters_nr = 0; filters_nc = 0; forward_algo = 0; backward_data_algo = 0; backward_filters_algo = 0; forward_workspace_size_in_bytes = 0; backward_data_workspace_size_in_bytes = 0; backward_filters_workspace_size_in_bytes = 0; forward_workspace.reset(); backward_data_workspace.reset(); backward_filters_workspace.reset(); } // Given an array of cudnn algorithm performance results, like // cudnnConvolutionFwdAlgoPerf_t, pick the best one to use. template decltype(std::declval().algo) pick_best_algorithm(const std::vector &perf_results) { DLIB_CASSERT(!perf_results.empty()); CHECK_CUDNN(perf_results[0].status); if (dnn_prefer_fastest_algorithms()) return perf_results[0].algo; // Otherwise we find the algorithm that has a good status and uses the least amount // of memory. size_t best_memory = std::numeric_limits::max(); decltype(std::declval().algo) best_alg; for (auto&& perf : perf_results) { if (perf.status == CUDNN_STATUS_SUCCESS && perf.memory < best_memory) { best_memory = perf.memory; best_alg = perf.algo; } } return best_alg; } void tensor_conv:: select_best_algorithms ( const tensor& data, const tensor_descriptor& dest_desc, allow_cache_use allow_cache_use_ ) { // Calling the cuDNN "find the best algorithm" functions is really slow. So we keep a // cache that tells us what method was best for a particular configuration. thread_local std::map, std::tuple> config_to_algo_cache; // If we have already found good algorithms for this setting then just pull them from // the cache. 
const auto cache_key = std::make_tuple(stride_y, stride_x, padding_y, padding_x, filters_nr, filters_nc); const auto iter = config_to_algo_cache.find(cache_key); if (iter != config_to_algo_cache.end() && allow_cache_use_ == allow_cache_use::yes) { std::tie(forward_algo, backward_data_algo, backward_filters_algo) = iter->second; return; } // Pick which forward algorithm we will use and allocate the necessary // workspace buffer. cudnnConvolutionFwdAlgo_t forward_best_algo; #if CUDNN_MAJOR >= 8 { int num_possible_algorithms = 0; CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possible_algorithms)); std::vector perf_results(num_possible_algorithms); int num_algorithms = 0; CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm( context(), descriptor(data), (const cudnnFilterDescriptor_t)filter_handle, (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(dest_desc), num_possible_algorithms, &num_algorithms, perf_results.data())); perf_results.resize(num_algorithms); forward_best_algo = pick_best_algorithm(perf_results); } #else CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm( context(), descriptor(data), (const cudnnFilterDescriptor_t)filter_handle, (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(dest_desc), dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, std::numeric_limits::max(), &forward_best_algo)); #endif forward_algo = forward_best_algo; // Pick which backward data algorithm we will use and allocate the // necessary workspace buffer. 
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo; #if CUDNN_MAJOR >= 8 { int num_possible_algorithms = 0; CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possible_algorithms)); std::vector perf_results(num_possible_algorithms); int num_algorithms = 0; CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm( context(), (const cudnnFilterDescriptor_t)filter_handle, descriptor(dest_desc), (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(data), num_possible_algorithms, &num_algorithms, perf_results.data())); perf_results.resize(num_algorithms); backward_data_best_algo = pick_best_algorithm(perf_results); } #else CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm( context(), (const cudnnFilterDescriptor_t)filter_handle, descriptor(dest_desc), (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(data), dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE, std::numeric_limits::max(), &backward_data_best_algo)); #endif backward_data_algo = backward_data_best_algo; // Pick which backward filters algorithm we will use and allocate the // necessary workspace buffer. 
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo; #if CUDNN_MAJOR >= 8 { int num_possible_algorithms = 0; CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possible_algorithms)); std::vector perf_results(num_possible_algorithms); int num_algorithms = 0; CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm( context(), descriptor(data), descriptor(dest_desc), (const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnFilterDescriptor_t)filter_handle, num_possible_algorithms, &num_algorithms, perf_results.data())); perf_results.resize(num_algorithms); backward_filters_best_algo = pick_best_algorithm(perf_results); } #else CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm( context(), descriptor(data), descriptor(dest_desc), (const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnFilterDescriptor_t)filter_handle, dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE, std::numeric_limits::max(), &backward_filters_best_algo)); #endif #if CUDNN_MAJOR < 7 // cuDNN 5.1 has a bug that causes // cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd // algorithm even for cases where cuDNN doesn't support it, leading to // incorrect outputs. So here we check if we are in a case where winograd // isn't supported and manually overrule // cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe // algorithm. 
if (dnn_prefer_fastest_algorithms() && !(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5))) ) { backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; } #endif backward_filters_algo = backward_filters_best_algo; // Save this algorithm selection in the cache config_to_algo_cache[cache_key] = std::make_tuple(forward_algo, backward_data_algo, backward_filters_algo); } void tensor_conv:: update_convolution_data_workspace_sizes( const tensor& data, const tensor_descriptor& dest_desc ) { CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize( context(), descriptor(data), (const cudnnFilterDescriptor_t)filter_handle, (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(dest_desc), (cudnnConvolutionFwdAlgo_t)forward_algo, &forward_workspace_size_in_bytes)); CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize( context(), (const cudnnFilterDescriptor_t)filter_handle, descriptor(dest_desc), (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(data), (cudnnConvolutionBwdDataAlgo_t)backward_data_algo, &backward_data_workspace_size_in_bytes)); CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize( context(), descriptor(data), descriptor(dest_desc), (const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnFilterDescriptor_t)filter_handle, (cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo, &backward_filters_workspace_size_in_bytes)); } void tensor_conv:: setup( const tensor& data, const tensor& filters, int stride_y_, int stride_x_, int padding_y_, int padding_x_ ) { DLIB_CASSERT(data.k() == filters.k()); // if the last call to setup gave the same exact settings then don't do // anything. 
if (data_num_samples == data.num_samples() && data_k == data.k() && data_nr == data.nr() && data_nc == data.nc() && stride_y_ == stride_y && stride_x_ == stride_x && padding_y_ == padding_y && padding_x_ == padding_x && filters_num_samples == filters.num_samples() && filters_k == filters.k() && filters_nr == filters.nr() && filters_nc == filters.nc() ) { return; } clear(); try { stride_y = stride_y_; stride_x = stride_x_; padding_y = padding_y_; padding_x = padding_x_; data_num_samples = data.num_samples(); data_k = data.k(); data_nr = data.nr(); data_nc = data.nc(); filters_num_samples = filters.num_samples(); filters_k = filters.k(); filters_nr = filters.nr(); filters_nc = filters.nc(); CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle)); CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, filters.num_samples(), filters.k(), filters.nr(), filters.nc())); CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle)); #if CUDNN_MAJOR >= 6 CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle, padding_y, // vertical padding padding_x, // horizontal padding stride_y, stride_x, 1, 1, // must be 1,1 CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)); // could also be CUDNN_CONVOLUTION #else CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle, padding_y, // vertical padding padding_x, // horizontal padding stride_y, stride_x, 1, 1, // must be 1,1 CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION #endif #if CUDNN_MAJOR >= 8 // On Ampere and later GPUs, CUDNN_DEFAULT_MATH permits TF32 Tensor Core // operations which have reduced precision. Use CUDNN_FMA_MATH to force // true FP32 computation for consistent numerical results. 
CHECK_CUDNN(cudnnSetConvolutionMathType( (cudnnConvolutionDescriptor_t)conv_handle, CUDNN_FMA_MATH)); #endif CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim( (const cudnnConvolutionDescriptor_t)conv_handle, descriptor(data), (const cudnnFilterDescriptor_t)filter_handle, &out_num_samples, &out_k, &out_nr, &out_nc)); tensor_descriptor dest_desc; dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc); try { select_best_algorithms(data, dest_desc, allow_cache_use::yes); update_convolution_data_workspace_sizes(data, dest_desc); } catch (dlib::cudnn_error&) { // Sometimes the values stored in `config_to_algo_cache` do not quite work - // so let's get a fresh estimate, instead of using a cached value. select_best_algorithms(data, dest_desc, allow_cache_use::no); update_convolution_data_workspace_sizes(data, dest_desc); } } catch(...) { clear(); throw; } } tensor_conv:: ~tensor_conv ( ) { clear(); } void tensor_conv::operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters ) { DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function"); output.set_size(out_num_samples, out_k, out_nr, out_nc); (*this)(add_to_output, static_cast(output), data, filters); } void tensor_conv::operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters ) { DLIB_CASSERT(is_same_object(output,data) == false); DLIB_CASSERT(is_same_object(output,filters) == false); DLIB_CASSERT(filters.k() == data.k()); DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function"); DLIB_CASSERT(filters.nc() <= data.nc() + 2*padding_x, "Filter windows must be small enough to fit into the padded image." << "\n\t filters.nc(): " << filters.nc() << "\n\t data.nc(): " << data.nc() << "\n\t padding_x: " << padding_x ); DLIB_CASSERT(filters.nr() <= data.nr() + 2*padding_y, "Filter windows must be small enough to fit into the padded image." 
<< "\n\t filters.nr(): " << filters.nr() << "\n\t data.nr(): " << data.nr() << "\n\t padding_y: " << padding_y ); DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples()); DLIB_CASSERT(output.k() == filters.num_samples()); DLIB_CASSERT(output.nr() == 1+(data.nr()+2*padding_y-filters.nr())/stride_y); DLIB_CASSERT(output.nc() == 1+(data.nc()+2*padding_x-filters.nc())/stride_x); const float alpha = 1; const float beta = add_to_output ? 1 : 0; // Since cudnnConvolutionForward() is an asynchronous call, we need to hold a // reference to the workspace buffer so we can be sure it isn't reallocated // while the function is still executing on the device. But each time we come // here, we make sure to grab the latest workspace buffer so that, globally, we // minimize the number of such buffers. forward_workspace = device_global_buffer(forward_workspace_size_in_bytes); CHECK_CUDNN(cudnnConvolutionForward( context(), &alpha, descriptor(data), data.device(), (const cudnnFilterDescriptor_t)filter_handle, filters.device(), (const cudnnConvolutionDescriptor_t)conv_handle, (cudnnConvolutionFwdAlgo_t)forward_algo, forward_workspace, forward_workspace_size_in_bytes, &beta, descriptor(output), output.device())); } void tensor_conv::operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ) { DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function"); output.set_size(out_num_samples, out_k, out_nr, out_nc); (*this)(add_to_output, static_cast(output), data, filters, biases, use_relu); } void tensor_conv::operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ) { // Function cudnnConvolutionBiasActivationForward should only be called with CUDNN_ACTIVATION_IDENTITY when // the chosen forward algorithm is 
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, as cuDNN documentation explicitly says. // In case the algorithm is different, perform the forward pass and bias addition separately. // If use_relu is true, any algorithm can be used. if (!use_relu && forward_algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) { (*this)(add_to_output, output, data, filters); tt::add(1, output, 1, biases); return; } DLIB_CASSERT(is_same_object(output,data) == false); DLIB_CASSERT(is_same_object(output,filters) == false); DLIB_CASSERT(filters.k() == data.k()); DLIB_CASSERT(stride_y > 0 && stride_x > 0, "You must call setup() before calling this function"); DLIB_CASSERT(filters.nc() <= data.nc() + 2*padding_x, "Filter windows must be small enough to fit into the padded image." << "\n\t filters.nc(): " << filters.nc() << "\n\t data.nc(): " << data.nc() << "\n\t padding_x: " << padding_x ); DLIB_CASSERT(filters.nr() <= data.nr() + 2*padding_y, "Filter windows must be small enough to fit into the padded image." << "\n\t filters.nr(): " << filters.nr() << "\n\t data.nr(): " << data.nr() << "\n\t padding_y: " << padding_y ); DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples()); DLIB_CASSERT(output.k() == filters.num_samples()); DLIB_CASSERT(output.nr() == 1+(data.nr()+2*padding_y-filters.nr())/stride_y); DLIB_CASSERT(output.nc() == 1+(data.nc()+2*padding_x-filters.nc())/stride_x); DLIB_CASSERT(filters.num_samples() == biases.k()); const float alpha1 = 1; const float alpha2 = add_to_output ? 1 : 0; // Since cudnnConvolutionBiasActivationForward() is an asynchronous call, // we need to hold a reference to the workspace buffer so we can be sure it // isn't reallocated while the function is still executing on the device. // But each time we come here, we make sure to grab the latest workspace // buffer so that, globally, we minimize the number of such buffers. 
forward_workspace = device_global_buffer(forward_workspace_size_in_bytes); float* out = output.device(); const cudnnTensorDescriptor_t out_desc = descriptor(output); CHECK_CUDNN(cudnnConvolutionBiasActivationForward( context(), &alpha1, descriptor(data), data.device(), (const cudnnFilterDescriptor_t)filter_handle, filters.device(), (const cudnnConvolutionDescriptor_t)conv_handle, (cudnnConvolutionFwdAlgo_t)forward_algo, forward_workspace, forward_workspace_size_in_bytes, &alpha2, out_desc, out, descriptor(biases), biases.device(), use_relu ? relu_activation_descriptor() : identity_activation_descriptor(), out_desc, out)); } void tensor_conv::get_gradient_for_data ( const bool add_to_output, const tensor& gradient_input, const tensor& filters, tensor& data_gradient ) { const float alpha = 1; const float beta = add_to_output ? 1 : 0; // Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a // reference to the workspace buffer so we can be sure it isn't reallocated // while the function is still executing on the device. But each time we come // here, we make sure to grab the latest workspace buffer so that, globally, we // minimize the number of such buffers. backward_data_workspace = device_global_buffer(backward_data_workspace_size_in_bytes); CHECK_CUDNN(cudnnConvolutionBackwardData(context(), &alpha, (const cudnnFilterDescriptor_t)filter_handle, filters.device(), descriptor(gradient_input), gradient_input.device(), (const cudnnConvolutionDescriptor_t)conv_handle, (cudnnConvolutionBwdDataAlgo_t)backward_data_algo, backward_data_workspace, backward_data_workspace_size_in_bytes, &beta, descriptor(data_gradient), data_gradient.device())); } void tensor_conv:: get_gradient_for_filters ( const bool add_to_output, const tensor& gradient_input, const tensor& data, tensor& filters_gradient ) { const float alpha = 1; const float beta = add_to_output ? 
1 : 0; // Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a // reference to the workspace buffer so we can be sure it isn't reallocated // while the function is still executing on the device. But each time we come // here, we make sure to grab the latest workspace buffer so that, globally, we // minimize the number of such buffers. backward_filters_workspace = device_global_buffer(backward_filters_workspace_size_in_bytes); CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(), &alpha, descriptor(data), data.device(), descriptor(gradient_input), gradient_input.device(), (const cudnnConvolutionDescriptor_t)conv_handle, (cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo, backward_filters_workspace, backward_filters_workspace_size_in_bytes, &beta, (const cudnnFilterDescriptor_t)filter_handle, filters_gradient.device())); } // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ pooling::pooling ( ) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0),padding_y(0), padding_x(0) { } pooling::~pooling( ) { clear(); } void pooling:: clear( ) { if (handle) cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle); handle = nullptr; window_height = 0; window_width = 0; stride_y = 0; stride_x = 0; padding_y = 0; padding_x = 0; } void pooling:: setup_max_pooling( int window_height_, int window_width_, int stride_y_, int stride_x_, int padding_y_, int padding_x_ ) { setup(window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_MAX); do_max_pooling = true; } void pooling:: setup_avg_pooling( int window_height_, int window_width_, int stride_y_, int stride_x_, int padding_y_, int padding_x_ ) { setup(window_height_, window_width_, stride_y_, stride_x_, padding_y_, padding_x_, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING); do_max_pooling = false; } void pooling:: 
setup( int window_height_, int window_width_, int stride_y_, int stride_x_, int padding_y_, int padding_x_, int pooling_mode ) { DLIB_CASSERT (window_height_ > 0 && window_width_ > 0 && stride_y_ > 0 && stride_x_ > 0 , "window_height_: " << window_height_ << "\t\n window_width_: " << window_width_ << "\t\n stride_y_: " << stride_y_ << "\t\n stride_x_: " << stride_x_ ); DLIB_CASSERT( 0 <= padding_y_ && padding_y_ < window_height_ && 0 <= padding_x_ && padding_x_ < window_width_, "window_height_: " << window_height_ << "\t\n window_width_: " << window_width_ << "\t\n padding_y_: " << padding_y_ << "\t\n padding_x_: " << padding_x_ ); if (window_height == window_height_ && window_width == window_width_ && stride_y == stride_y_ && stride_x == stride_x_ && padding_y == padding_y_ && padding_x == padding_x_ ) { return; } clear(); try { window_height = window_height_; window_width = window_width_; stride_x = stride_x_; stride_y = stride_y_; padding_y = padding_y_; padding_x = padding_x_; cudnnPoolingDescriptor_t poolingDesc; CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc)); handle = poolingDesc; CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc, (cudnnPoolingMode_t)pooling_mode, CUDNN_PROPAGATE_NAN, window_height, window_width, padding_y, padding_x, stride_y, stride_x)); } catch(...) { clear(); throw; } } void pooling:: operator() ( resizable_tensor& dest, const tensor& src ) { DLIB_CASSERT(window_width <= src.nc() + 2*padding_x, "Pooling windows must be small enough to fit into the padded image." << "\n\t window_width: " << window_width << "\n\t src.nc(): " << src.nc() << "\n\t padding_x: " << padding_x ); DLIB_CASSERT(window_height <= src.nr() + 2*padding_y, "Pooling windows must be small enough to fit into the padded image." 
<< "\n\t window_height: " << window_height << "\n\t src.nr(): " << src.nr() << "\n\t padding_y: " << padding_y ); const float alpha = 1; const float beta = 0; int outN; int outC; int outH; int outW; CHECK_CUDNN(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle, descriptor(src), &outN, &outC, &outH, &outW)); dest.set_size(outN,outC,outH,outW); DLIB_CASSERT(dest.num_samples() == src.num_samples()); DLIB_CASSERT(dest.k() == src.k()); DLIB_CASSERT(dest.nr() == 1 + (src.nr() + 2*padding_y - window_height)/stride_y, "\n stride_y: " << stride_y << "\n padding_y: " << padding_y << "\n window_height: " << window_height << "\n src.nr(): " << src.nr() << "\n dest.nr(): " << dest.nr() << "\n src.nr()/stride_y: " << src.nr()/stride_y); DLIB_CASSERT(dest.nc() == 1 + (src.nc() + 2*padding_x - window_width)/stride_x, "\n stride_x: " << stride_x << "\n padding_x: " << padding_x << "\n window_width: " << window_width << "\n src.nc(): " << src.nc() << "\n dest.nc(): " << dest.nc() << "\n src.nc()/stride_x: " << src.nc()/stride_x); CHECK_CUDNN(cudnnPoolingForward(context(), (const cudnnPoolingDescriptor_t)handle, &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } void pooling::get_gradient( const tensor& gradient_input, const tensor& dest, const tensor& src, tensor& grad ) { DLIB_CASSERT(have_same_dimensions(gradient_input,dest)); DLIB_CASSERT(have_same_dimensions(src,grad)); const float alpha = 1; const float beta = 1; CHECK_CUDNN(cudnnPoolingBackward(context(), (const cudnnPoolingDescriptor_t)handle, &alpha, descriptor(dest), dest.device(), descriptor(gradient_input), gradient_input.device(), descriptor(src), src.device(), &beta, descriptor(grad), grad.device())); } // ------------------------------------------------------------------------------------ void softmax( tensor& dest, const tensor& src, operation_mode mode ) { DLIB_CASSERT(have_same_dimensions(dest, src)); if (src.size() == 0) return; const float alpha = 1; const 
float beta = 0; if (mode == operation_mode::CHANNEL_WISE) { CHECK_CUDNN(cudnnSoftmaxForward(context(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } else if (mode == operation_mode::PLANE_WISE) { const long num_samples = src.num_samples(); const long num_channels = src.k(); const size_t plane_size = src.nr() * src.nc(); for (long s = 0; s < num_samples; ++s) { for (long k = 0; k < num_channels; ++k) { auto src_slice = src.device() + (s * num_channels + k) * plane_size; auto dest_slice = dest.device() + (s * num_channels + k) * plane_size; auto a_src_slice = alias_tensor(src.nr(), src.nc())(src, (s * num_channels + k) * plane_size); auto a_dest_slice = alias_tensor(dest.nr(), dest.nc())(dest, (s * num_channels + k) * plane_size); CHECK_CUDNN(cudnnSoftmaxForward(context(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, descriptor(a_src_slice), src_slice, &beta, descriptor(a_dest_slice), dest_slice)); } } } } void softmax_gradient( tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode ) { DLIB_CASSERT( have_same_dimensions(dest, gradient_input) == true && have_same_dimensions(dest, grad) == true); if (dest.size() == 0) return; const float alpha = 1; const float beta = is_same_object(grad, gradient_input) ? 
0 : 1; if (mode == operation_mode::CHANNEL_WISE) { CHECK_CUDNN(cudnnSoftmaxBackward(context(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, descriptor(dest), dest.device(), descriptor(gradient_input), gradient_input.device(), &beta, descriptor(grad), grad.device())); } else if (mode == operation_mode::PLANE_WISE) { const long num_samples = dest.num_samples(); const long num_channels = dest.k(); const size_t plane_size = dest.nr() * dest.nc(); for (long s = 0; s < num_samples; ++s) { for (long k = 0; k < num_channels; ++k) { auto dest_slice = dest.device() + (s * num_channels + k) * plane_size; auto gi_slice = gradient_input.device() + (s * num_channels + k) * plane_size; auto grad_slice = grad.device() + (s * num_channels + k) * plane_size; auto a_dest_slice = alias_tensor(dest.nr(), dest.nc())(dest, (s * num_channels + k) * plane_size); auto a_gi_slice = alias_tensor(gradient_input.nr(), gradient_input.nc())(gradient_input, (s * num_channels + k) * plane_size); auto a_grad_slice = alias_tensor(grad.nr(), grad.nc())(grad, (s * num_channels + k) * plane_size); CHECK_CUDNN(cudnnSoftmaxBackward(context(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, descriptor(a_dest_slice), dest_slice, descriptor(a_gi_slice), gi_slice, &beta, descriptor(a_grad_slice), grad_slice)); } } } } // ------------------------------------------------------------------------------------ void softmax_all ( tensor& dest, const tensor& src ) { DLIB_CASSERT(have_same_dimensions(dest,src)); if (src.size() == 0) return; const float alpha = 1; const float beta = 0; CHECK_CUDNN(cudnnSoftmaxForward(context(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } void softmax_all_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { DLIB_CASSERT( have_same_dimensions(dest,gradient_input) == true && have_same_dimensions(dest,grad) == true ); if (dest.size() == 0) return; 
const float alpha = 1; const float beta = is_same_object(grad,gradient_input) ? 0 : 1; CHECK_CUDNN(cudnnSoftmaxBackward(context(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, descriptor(dest), dest.device(), descriptor(gradient_input), gradient_input.device(), &beta, descriptor(grad), grad.device())); } // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ void sigmoid ( tensor& dest, const tensor& src ) { DLIB_CASSERT(have_same_dimensions(dest,src)); if (src.size() == 0) return; const float alpha = 1; const float beta = 0; CHECK_CUDNN(cudnnActivationForward(context(), sigmoid_activation_descriptor(), &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } void sigmoid_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { DLIB_CASSERT( have_same_dimensions(dest,gradient_input) == true && have_same_dimensions(dest,grad) == true ); if (dest.size() == 0) return; const float alpha = 1; const float beta = is_same_object(grad,gradient_input) ? 
0 : 1; CHECK_CUDNN(cudnnActivationBackward(context(), sigmoid_activation_descriptor(), &alpha, descriptor(dest), dest.device(), descriptor(gradient_input), gradient_input.device(), descriptor(dest), dest.device(), &beta, descriptor(grad), grad.device())); } // ------------------------------------------------------------------------------------ void relu ( tensor& dest, const tensor& src ) { DLIB_CASSERT(have_same_dimensions(dest,src)); if (src.size() == 0) return; const float alpha = 1; const float beta = 0; CHECK_CUDNN(cudnnActivationForward(context(), relu_activation_descriptor(), &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } void relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { DLIB_CASSERT( have_same_dimensions(dest,gradient_input) == true && have_same_dimensions(dest,grad) == true ); if (dest.size() == 0) return; const float alpha = 1; const float beta = is_same_object(grad,gradient_input) ? 0 : 1; CHECK_CUDNN(cudnnActivationBackward(context(), relu_activation_descriptor(), &alpha, descriptor(dest), dest.device(), descriptor(gradient_input), gradient_input.device(), descriptor(dest), dest.device(), &beta, descriptor(grad), grad.device())); } // ------------------------------------------------------------------------------------ void tanh ( tensor& dest, const tensor& src ) { DLIB_CASSERT(have_same_dimensions(dest,src)); if (src.size() == 0) return; const float alpha = 1; const float beta = 0; CHECK_CUDNN(cudnnActivationForward(context(), tanh_activation_descriptor(), &alpha, descriptor(src), src.device(), &beta, descriptor(dest), dest.device())); } void tanh_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { DLIB_CASSERT( have_same_dimensions(dest,gradient_input) == true && have_same_dimensions(dest,grad) == true); if (dest.size() == 0) return; const float alpha = 1; const float beta = is_same_object(grad,gradient_input) ? 
0 : 1; CHECK_CUDNN(cudnnActivationBackward(context(), tanh_activation_descriptor(), &alpha, descriptor(dest), dest.device(), descriptor(gradient_input), gradient_input.device(), descriptor(dest), dest.device(), &beta, descriptor(grad), grad.device())); } // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuDNN_CPP_ ================================================ FILE: dlib/cuda/cudnn_dlibapi.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CuDNN_H_ #define DLIB_DNN_CuDNN_H_ #include #include "operation_mode.h" #ifdef DLIB_USE_CUDA #include "cuda_errors.h" #include "cuda_data_ptr.h" #endif // DLIB_USE_CUDA namespace dlib { class tensor; class resizable_tensor; #ifdef DLIB_USE_CUDA namespace cuda { // ----------------------------------------------------------------------------------- class tensor_descriptor { /*! Each tensor object will carry a tensor_descriptor in it when compiled with CUDA. !*/ public: // not copyable tensor_descriptor(const tensor_descriptor&) = delete; tensor_descriptor& operator=(const tensor_descriptor&) = delete; // but is movable tensor_descriptor(tensor_descriptor&& item) : tensor_descriptor() { swap(item); } tensor_descriptor& operator=(tensor_descriptor&& item) { swap(item); return *this; } tensor_descriptor(); ~tensor_descriptor(); void set_size( int n, int k, int nr, int nc ); /*! ensures - if any of the arguments are 0 then they are all set to 0 in the tensor. 
!*/ void get_size ( int& n, int& k, int& nr, int& nc ) const; const void* get_handle ( ) const { return handle; } private: void swap(tensor_descriptor& item) { std::swap(handle, item.handle); } void* handle; }; // ------------------------------------------------------------------------------------ void add( float beta, tensor& dest, float alpha, const tensor& src ); // ------------------------------------------------------------------------------------ void assign_conv_bias_gradient ( tensor& grad, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void batch_normalize_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ); void batch_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ); void batch_normalize_gradient( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ); // ------------------------------------------------------------------------------------ void batch_normalize_conv_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ); void batch_normalize_conv ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ); void batch_normalize_conv_gradient( const double eps, const tensor& gradient_input, const tensor& 
means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ); // ------------------------------------------------------------------------------------ class tensor_conv { public: tensor_conv(const tensor_conv&) = delete; tensor_conv& operator=(const tensor_conv&) = delete; tensor_conv(); void clear( ); ~tensor_conv ( ); void operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters ); void operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters ); void operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ); void operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ); void get_gradient_for_data ( const bool add_to_output, const tensor& gradient_input, const tensor& filters, tensor& data_gradient ); void get_gradient_for_filters ( const bool add_to_output, const tensor& gradient_input, const tensor& data, tensor& filters_gradient ); void setup( const tensor& data, const tensor& filters, int stride_y, int stride_x, int padding_y, int padding_x ); void setup( const tensor& data, const tensor& filters, const tensor& biases, int stride_y, int stride_x, int padding_y, int padding_x ); private: // These variables record the type of data given to the last call to setup(). int stride_y; int stride_x; int padding_y; int padding_x; long data_num_samples, data_k, data_nr, data_nc; long filters_num_samples, filters_k, filters_nr, filters_nc; void* filter_handle; void* conv_handle; // dimensions of the output tensor from operator() int out_num_samples; int out_k; int out_nr; int out_nc; enum class allow_cache_use { no, yes }; // sets the three _algo fields. 
void select_best_algorithms(const tensor& data, const tensor_descriptor& dest_desc, allow_cache_use allow_cache_use); int forward_algo; int backward_data_algo; int backward_filters_algo; // sets the three _workspace_size_in_bytes fields. void update_convolution_data_workspace_sizes(const tensor& data, const tensor_descriptor& dest_desc); size_t forward_workspace_size_in_bytes; size_t backward_data_workspace_size_in_bytes; size_t backward_filters_workspace_size_in_bytes; cuda_data_void_ptr forward_workspace; cuda_data_void_ptr backward_data_workspace; cuda_data_void_ptr backward_filters_workspace; }; // ------------------------------------------------------------------------------------ class pooling { public: pooling(const pooling&) = delete; pooling& operator=(const pooling&) = delete; pooling ( ); ~pooling( ); void clear( ); void setup_max_pooling( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x ); void setup_avg_pooling( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x ); bool does_max_pooling( ) const { return do_max_pooling; } void operator() ( resizable_tensor& dest, const tensor& src ); void get_gradient( const tensor& gradient_input, const tensor& dest, const tensor& src, tensor& grad ); private: void setup( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x, int pooling_mode ); void* handle; int window_height; int window_width; int stride_y; int stride_x; int padding_y; int padding_x; bool do_max_pooling; }; // ------------------------------------------------------------------------------------ void softmax ( tensor& dest, const tensor& src, operation_mode mode = operation_mode::CHANNEL_WISE ); void softmax_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode = operation_mode::CHANNEL_WISE ); // ------------------------------------------------------------------------------------ void 
softmax_all ( tensor& dest, const tensor& src ); void softmax_all_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void sigmoid ( tensor& dest, const tensor& src ); void sigmoid_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void relu ( tensor& dest, const tensor& src ); void relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ void tanh ( tensor& dest, const tensor& src ); void tanh_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); // ------------------------------------------------------------------------------------ } #endif // DLIB_USE_CUDA } #endif // DLIB_DNN_CuDNN_H_ ================================================ FILE: dlib/cuda/curand_dlibapi.cpp ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CuRAND_CPP_ #define DLIB_DNN_CuRAND_CPP_ #ifdef DLIB_USE_CUDA #include "curand_dlibapi.h" #include #include "../string.h" static const char* curand_get_error_string(curandStatus_t s) { switch(s) { case CURAND_STATUS_NOT_INITIALIZED: return "CUDA Runtime API initialization failed."; case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "The requested length must be a multiple of two."; default: return "A call to cuRAND failed"; } } // Check the return value of a call to the cuDNN runtime for an error condition. #define CHECK_CURAND(call) \ do{ \ const curandStatus_t error = call; \ if (error != CURAND_STATUS_SUCCESS) \ { \ std::ostringstream sout; \ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". 
";\ sout << "code: " << error << ", reason: " << curand_get_error_string(error);\ throw dlib::curand_error(sout.str()); \ } \ }while(false) namespace dlib { namespace cuda { // ---------------------------------------------------------------------------------------- curand_generator:: curand_generator( unsigned long long seed ) : handle(nullptr) { curandGenerator_t gen; CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); handle = gen; CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed)); } curand_generator:: ~curand_generator() { if (handle) { curandDestroyGenerator((curandGenerator_t)handle); } } void curand_generator:: fill_gaussian ( tensor& data, float mean, float stddev ) { if (data.size() == 0) return; CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle, data.device(), data.size(), mean, stddev)); } void curand_generator:: fill_uniform ( tensor& data ) { if (data.size() == 0) return; CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size())); } void curand_generator:: fill ( cuda_data_ptr& data ) { if (data.size() == 0) return; CHECK_CURAND(curandGenerate((curandGenerator_t)handle, data, data.size())); } // ----------------------------------------------------------------------------------- } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuRAND_CPP_ ================================================ FILE: dlib/cuda/curand_dlibapi.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DNN_CuRAND_H_
#define DLIB_DNN_CuRAND_H_

#ifdef DLIB_USE_CUDA

#include "tensor.h"
#include "cuda_errors.h"
#include "cuda_data_ptr.h"

namespace dlib
{
    namespace cuda
    {

    // -----------------------------------------------------------------------------------

        class curand_generator
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    A thin wrapper around a cuRAND pseudo-random number generator.  It is
                    used to fill device-side buffers with random values.  The generator
                    lives behind an opaque handle that is created by the constructor and
                    released by the destructor.
            !*/
        public:

            curand_generator() : curand_generator(0) {}
            /*!
                ensures
                    - Same as constructing with curand_generator(0), i.e. a generator
                      seeded with 0.
            !*/

            curand_generator(unsigned long long seed);
            /*!
                ensures
                    - Creates a new generator whose pseudo-random stream is determined by
                      seed.
                throws
                    - dlib::curand_error if cuRAND reports a failure.
                NOTE(review): this constructor is not explicit, so an integer converts
                implicitly to a curand_generator.
            !*/

            ~curand_generator();
            /*!
                ensures
                    - Releases the underlying cuRAND generator.
            !*/

            // Copying is disabled: the destructor releases the handle, so two copies
            // would release the same generator twice.
            curand_generator(const curand_generator&) = delete;
            curand_generator& operator=(const curand_generator&) = delete;

            void fill_uniform (
                tensor& data
            );
            /*!
                ensures
                    - Overwrites every element of data with a uniformly distributed random
                      value in (0.0, 1.0], i.e. 0 is excluded and 1 is included.
                    - Does nothing if data is empty.
            !*/

            void fill_gaussian (
                tensor& data,
                float mean = 0,
                float stddev = 1
            );
            /*!
                requires
                    - data.size()%2 == 0
                    - stddev >= 0
                ensures
                    - Overwrites every element of data with a sample from a normal
                      distribution that has the given mean and standard deviation.
                    - Does nothing if data is empty.
            !*/

            void fill (
                cuda_data_ptr& data
            );
            /*!
                ensures
                    - Overwrites every element of data with a random 32-bit unsigned
                      integer.
                    - Does nothing if data is empty.
            !*/

        private:

            // Opaque handle to the cuRAND generator owned by this object.
            void* handle;
        };

    // -----------------------------------------------------------------------------------

    }
}

#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuRAND_H_ ================================================ FILE: dlib/cuda/cusolver_dlibapi.cu ================================================ // Copyright (C) 2017 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuSOLVER_CU_ #define DLIB_DNN_CuSOLVER_CU_ #ifdef DLIB_USE_CUDA #include "cusolver_dlibapi.h" #include #include #include "cuda_utils.h" // ---------------------------------------------------------------------------------------- static const char* cusolver_get_error_string(cusolverStatus_t s) { switch(s) { case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUDA Runtime API initialization failed."; case CUSOLVER_STATUS_ALLOC_FAILED: return "CUDA Resources could not be allocated."; default: return "A call to cuSolver failed"; } } // Check the return value of a call to the cuSolver runtime for an error condition. #define CHECK_CUSOLVER(call) \ do{ \ const cusolverStatus_t error = call; \ if (error != CUSOLVER_STATUS_SUCCESS) \ { \ std::ostringstream sout; \ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\ sout << "code: " << error << ", reason: " << cusolver_get_error_string(error);\ throw dlib::cusolver_error(sout.str()); \ } \ }while(false) // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- namespace dlib { namespace cuda { // ----------------------------------------------------------------------------------- class cusolver_context { public: // not copyable cusolver_context(const cusolver_context&) = delete; cusolver_context& operator=(const cusolver_context&) = delete; cusolver_context() { handles.resize(16); } ~cusolver_context() { for (auto h : handles) { if (h) cusolverDnDestroy(h); } } cusolverDnHandle_t get_handle ( ) { int new_device_id; CHECK_CUDA(cudaGetDevice(&new_device_id)); // make room for more devices if needed if (new_device_id >= (long)handles.size()) handles.resize(new_device_id+16); // If we don't have a handle already for this device then make one if (!handles[new_device_id]) CHECK_CUSOLVER(cusolverDnCreate(&handles[new_device_id])); // Finally, return the 
handle for the current device return handles[new_device_id]; } private: std::vector handles; }; static cusolverDnHandle_t context() { thread_local cusolver_context c; return c.get_handle(); } // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ __global__ void _cuda_set_to_identity_matrix(float* m, size_t nr) { for (auto j : grid_stride_range(0, nr*nr)) { if (j%(nr+1) == 0) m[j] = 1; else m[j] = 0; } } void set_to_identity_matrix ( tensor& m ) { DLIB_CASSERT(m.size() == m.num_samples()*m.num_samples()); launch_kernel(_cuda_set_to_identity_matrix, max_jobs(m.size()), m.device(), m.num_samples()); } // ------------------------------------------------------------------------------------ inv::~inv() { sync_if_needed(); } // ------------------------------------------------------------------------------------ void inv:: operator() ( const tensor& m_, resizable_tensor& out ) { DLIB_CASSERT(m_.size() == m_.num_samples()*m_.num_samples(), "Input matrix must be square if you want to invert it."); m = m_; out.copy_size(m); set_to_identity_matrix(out); const int nc = m.num_samples(); int Lwork; CHECK_CUSOLVER(cusolverDnSgetrf_bufferSize(context(), nc , nc, m.device(), nc, &Lwork)); if (Lwork > (int)workspace.size()) { sync_if_needed(); workspace = cuda_data_ptr(Lwork); } if (nc > (int)Ipiv.size()) { sync_if_needed(); Ipiv = cuda_data_ptr(nc); } if (info.size() != 1) { info = cuda_data_ptr(1); } CHECK_CUSOLVER(cusolverDnSgetrf(context(), nc, nc, m.device(), nc, workspace, Ipiv, info)); CHECK_CUSOLVER(cusolverDnSgetrs(context(), CUBLAS_OP_N, nc, nc, m.device(), nc, Ipiv, out.device(), nc, info)); did_work_lately = true; } // ------------------------------------------------------------------------------------ int inv:: get_last_status( ) { std::vector linfo; memcpy(linfo, 
info); if (linfo.size() != 0) return linfo[0]; else return 0; } // ------------------------------------------------------------------------------------ void inv:: sync_if_needed() { if (did_work_lately) { did_work_lately = false; // make sure we wait until any previous kernel launches have finished // before we do something like deallocate the GPU memory. cudaDeviceSynchronize(); } } // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuSOLVER_CU_ ================================================ FILE: dlib/cuda/cusolver_dlibapi.h ================================================ // Copyright (C) 2017 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNN_CuSOLVER_H_ #define DLIB_DNN_CuSOLVER_H_ #ifdef DLIB_USE_CUDA #include "tensor.h" #include "cuda_errors.h" #include "cuda_data_ptr.h" #include "../noncopyable.h" namespace dlib { namespace cuda { // ----------------------------------------------------------------------------------- class inv : noncopyable { /*! WHAT THIS OBJECT REPRESENTS This is a functor for doing matrix inversion on the GPU. The only reason it's an object is to avoid the reallocation of some GPU memory blocks if you want to do a bunch of matrix inversions in a row. !*/ public: inv() = default; ~inv(); void operator() ( const tensor& m, resizable_tensor& out ); /*! requires - m.size() == m.num_samples()*m.num_samples() (i.e. mat(m) must be a square matrix) ensures - out == inv(mat(m)); !*/ int get_last_status( ); /*! ensures - returns 0 if the last matrix inversion was successful and != 0 otherwise. 
!*/ private: void sync_if_needed(); bool did_work_lately = false; resizable_tensor m; cuda_data_ptr workspace; cuda_data_ptr Ipiv; cuda_data_ptr info; }; // ------------------------------------------------------------------------------------ } } #endif // DLIB_USE_CUDA #endif // DLIB_DNN_CuSOLVER_H_ ================================================ FILE: dlib/cuda/gpu_data.cpp ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_GPU_DaTA_CPP_ #define DLIB_GPU_DaTA_CPP_ // Only things that require CUDA are declared in this cpp file. Everything else is in the // gpu_data.h header so that it can operate as "header-only" code when using just the CPU. #ifdef DLIB_USE_CUDA #include "gpu_data.h" #include #include "cuda_utils.h" #include #include /* NOTE(review): three of the #include directives above have no target; this looks like extraction damage (text between angle brackets lost), so regenerate this file from the repository rather than editing it by hand. */ namespace dlib { // ---------------------------------------------------------------------------------------- /* Copies all of src into dest. The sizes must match (checked with DLIB_CASSERT); empty or self copies are no-ops and everything else is forwarded to the ranged overload. */ void memcpy ( gpu_data& dest, const gpu_data& src ) { DLIB_CASSERT(dest.size() == src.size()); if (src.size() == 0 || &dest == &src) return; memcpy(dest,0, src, 0, src.size()); } /* Copies num floats starting at src_offset in src into dest starting at dest_offset. Overlapping ranges inside one object are moved on the host with std::memmove. Otherwise the cudaMemcpy direction is picked from which copy (host or device) of each object is current, and a copy that covers all of dest writes through device_write_only() so dest's old contents are never transferred. */ void memcpy ( gpu_data& dest, size_t dest_offset, const gpu_data& src, size_t src_offset, size_t num ) { DLIB_CASSERT(dest_offset + num <= dest.size()); DLIB_CASSERT(src_offset + num <= src.size()); if (num == 0) return; // if there is aliasing if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num) { // if they perfectly alias each other then there is nothing to do if (dest_offset == src_offset) return; else std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num); } else { // if we write to the entire thing then we can use device_write_only() if (dest_offset == 0 && num == dest.size()) { // copy the memory efficiently based on which copy is current in each object.
if (src.device_ready()) CHECK_CUDA(cudaMemcpy(dest.device_write_only(), src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToDevice)); else CHECK_CUDA(cudaMemcpy(dest.device_write_only(), src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToDevice)); } else { // copy the memory efficiently based on which copy is current in each object. if (dest.device_ready() && src.device_ready()) CHECK_CUDA(cudaMemcpy(dest.device()+dest_offset, src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToDevice)); else if (!dest.device_ready() && src.device_ready()) CHECK_CUDA(cudaMemcpy(dest.host()+dest_offset, src.device()+src_offset, num*sizeof(float), cudaMemcpyDeviceToHost)); else if (dest.device_ready() && !src.device_ready()) CHECK_CUDA(cudaMemcpy(dest.device()+dest_offset, src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToDevice)); else CHECK_CUDA(cudaMemcpy(dest.host()+dest_offset, src.host()+src_offset, num*sizeof(float), cudaMemcpyHostToHost)); } } } // ---------------------------------------------------------------------------------------- /* Blocks until all work queued on stream has finished; CUDA errors are reported through CHECK_CUDA. */ void synchronize_stream(cudaStream_t stream) { #if !defined CUDA_VERSION #error CUDA_VERSION not defined #elif CUDA_VERSION >= 9020 && CUDA_VERSION < 11000 // We will stop using this alternative version with cuda V11, hopefully the bug in // cudaStreamSynchronize is fixed by then. // // This should be pretty much the same as cudaStreamSynchronize, which for some // reason makes training freeze in some cases.
// (see https://github.com/davisking/dlib/issues/1513) while (true) { cudaError_t err = cudaStreamQuery(stream); switch (err) { case cudaSuccess: return; // now we are synchronized case cudaErrorNotReady: break; // continue waiting default: CHECK_CUDA(err); // unexpected error: throw } } #else // CUDA_VERSION CHECK_CUDA(cudaStreamSynchronize(stream)); #endif // CUDA_VERSION } /* If an asynchronous host-to-device copy is pending, waits on this object's own stream and clears have_active_transfer. */ void gpu_data:: wait_for_transfer_to_finish() const { if (have_active_transfer) { synchronize_stream((cudaStream_t)cuda_stream.get()); have_active_transfer = false; // Check for errors. These calls to cudaGetLastError() are what help us find // out if our kernel launches have been failing. CHECK_CUDA(cudaGetLastError()); } } void gpu_data:: copy_to_device() const { // We want transfers to the device to always be concurrent with any device // computation. So we use our non-default stream to do the transfer. async_copy_to_device(); wait_for_transfer_to_finish(); } /* If the host copy is stale, waits for any pending upload and then synchronously copies the device data back to the host. */ void gpu_data:: copy_to_host() const { if (!host_current) { wait_for_transfer_to_finish(); CHECK_CUDA(cudaMemcpy(data_host.get(), data_device.get(), data_size*sizeof(float), cudaMemcpyDeviceToHost)); host_current = true; // At this point we know our RAM block isn't in use because cudaMemcpy() // implicitly syncs with the device. device_in_use = false; // Check for errors. These calls to cudaGetLastError() are what help us find // out if our kernel launches have been failing. CHECK_CUDA(cudaGetLastError()); } } /* If the device copy is stale, queues a host-to-device copy on this object's own stream without waiting for it to complete (copy_to_device() and wait_for_transfer_to_finish() do the waiting). */ void gpu_data:: async_copy_to_device() const { if (!device_current) { if (device_in_use) { // Wait for any possible CUDA kernels that might be using our memory block to // complete before we overwrite the memory.
synchronize_stream(0); device_in_use = false; } CHECK_CUDA(cudaMemcpyAsync(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice, (cudaStream_t)cuda_stream.get())); have_active_transfer = true; device_current = true; } } /* Resizes the buffer. A size change discards the old contents: new page-locked host memory (cudaMallocHost) and device memory (cudaMalloc) are allocated uninitialized and both copies are marked current. Requesting the current non-zero size is a no-op. If any CUDA call fails the object is reset to size 0 and the exception is rethrown. */ void gpu_data:: set_size( size_t new_size ) { if (new_size == 0) { if (device_in_use) { // Wait for any possible CUDA kernels that might be using our memory block to // complete before we free the memory. synchronize_stream(0); device_in_use = false; } wait_for_transfer_to_finish(); data_size = 0; host_current = true; device_current = true; device_in_use = false; data_host.reset(); data_device.reset(); } else if (new_size != data_size) { if (device_in_use) { // Wait for any possible CUDA kernels that might be using our memory block to // complete before we free the memory. synchronize_stream(0); device_in_use = false; } wait_for_transfer_to_finish(); data_size = new_size; host_current = true; device_current = true; device_in_use = false; try { CHECK_CUDA(cudaGetDevice(&the_device_id)); // free memory blocks before we allocate new ones. data_host.reset(); data_device.reset(); void* data; CHECK_CUDA(cudaMallocHost(&data, new_size*sizeof(float))); // Note that we don't throw exceptions since the free calls are invariably // called in destructors. They also shouldn't fail anyway unless someone // is resetting the GPU card in the middle of their program. data_host.reset((float*)data, [](float* ptr){ auto err = cudaFreeHost(ptr); if(err!=cudaSuccess) std::cerr << "cudaFreeHost() failed. Reason: " << cudaGetErrorString(err) << std::endl; }); CHECK_CUDA(cudaMalloc(&data, new_size*sizeof(float))); data_device.reset((float*)data, [](float* ptr){ auto err = cudaFree(ptr); if(err!=cudaSuccess) std::cerr << "cudaFree() failed.
Reason: " << cudaGetErrorString(err) << std::endl; }); /* The transfer stream is created lazily, once per object, and is non-blocking with respect to the default stream. */ if (!cuda_stream) { cudaStream_t cstream; CHECK_CUDA(cudaStreamCreateWithFlags(&cstream, cudaStreamNonBlocking)); cuda_stream.reset(cstream, [](void* ptr){ auto err = cudaStreamDestroy((cudaStream_t)ptr); if(err!=cudaSuccess) std::cerr << "cudaStreamDestroy() failed. Reason: " << cudaGetErrorString(err) << std::endl; }); } } catch(...) { set_size(0); throw; } } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_USE_CUDA #endif // DLIB_GPU_DaTA_CPP_ ================================================ FILE: dlib/cuda/gpu_data.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_GPU_DaTA_H_ #define DLIB_GPU_DaTA_H_ #include "gpu_data_abstract.h" #include #include #include "cuda_errors.h" #include "../serialize.h" namespace dlib { // ---------------------------------------------------------------------------------------- class gpu_data { /*! CONVENTION - if (size() != 0) then - data_host == a pointer to size() floats in CPU memory. - if (data_device) then - data_device == a pointer to size() floats in device memory. - if (there might be an active async transfer from host to device) then - have_active_transfer == true - We use the host_current and device_current bools to keep track of which copy of the data (or both) are most current. e.g. if the CPU has modified the data and it hasn't been copied to the device yet then host_current==true and device_current==false. Similarly, we use device_in_use==true to indicate that device() has been called and no operation to wait for all CUDA kernel completion has been executed. So if device_in_use==true then there might be a CUDA kernel executing that is using the device memory block contained in this object.
!*/ public: gpu_data( ) : data_size(0), host_current(true), device_current(true),have_active_transfer(false),device_in_use(false), the_device_id(0) { } // Not copyable gpu_data(const gpu_data&) = delete; gpu_data& operator=(const gpu_data&) = delete; // but is movable gpu_data(gpu_data&& item) : gpu_data() { swap(item); } gpu_data& operator=(gpu_data&& item) { swap(item); return *this; } int device_id() const { return the_device_id; } #ifdef DLIB_USE_CUDA void async_copy_to_device() const; void set_size(size_t new_size); #else // Note that calls to host() or device() will block until any async transfers are complete. void async_copy_to_device() const{} void set_size(size_t new_size) { if (new_size == 0) { data_size = 0; host_current = true; device_current = true; device_in_use = false; data_host.reset(); data_device.reset(); } else if (new_size != data_size) { data_size = new_size; host_current = true; device_current = true; device_in_use = false; data_host.reset(new float[new_size], std::default_delete()); data_device.reset(); } } #endif const float* host() const { copy_to_host(); return data_host.get(); } float* host() { copy_to_host(); device_current = false; return data_host.get(); } float* host_write_only() { host_current = true; device_current = false; return data_host.get(); } const float* device() const { #ifndef DLIB_USE_CUDA DLIB_CASSERT(false, "CUDA NOT ENABLED"); #endif copy_to_device(); device_in_use = true; return data_device.get(); } float* device() { #ifndef DLIB_USE_CUDA DLIB_CASSERT(false, "CUDA NOT ENABLED"); #endif copy_to_device(); host_current = false; device_in_use = true; return data_device.get(); } float* device_write_only() { #ifndef DLIB_USE_CUDA DLIB_CASSERT(false, "CUDA NOT ENABLED"); #endif wait_for_transfer_to_finish(); host_current = false; device_current = true; device_in_use = true; return data_device.get(); } bool host_ready ( ) const { return host_current; } bool device_ready ( ) const { return device_current && 
!have_active_transfer; } size_t size() const { return data_size; } void swap (gpu_data& item) { std::swap(data_size, item.data_size); std::swap(host_current, item.host_current); std::swap(device_current, item.device_current); std::swap(have_active_transfer, item.have_active_transfer); std::swap(data_host, item.data_host); std::swap(data_device, item.data_device); std::swap(cuda_stream, item.cuda_stream); std::swap(the_device_id, item.the_device_id); } private: #ifdef DLIB_USE_CUDA void copy_to_device() const; void copy_to_host() const; void wait_for_transfer_to_finish() const; #else void copy_to_device() const{} void copy_to_host() const{} void wait_for_transfer_to_finish() const{} #endif size_t data_size; mutable bool host_current; mutable bool device_current; mutable bool have_active_transfer; mutable bool device_in_use; std::shared_ptr data_host; std::shared_ptr data_device; std::shared_ptr cuda_stream; int the_device_id; }; inline void serialize(const gpu_data& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.size(), out); auto data = item.host(); for (size_t i = 0; i < item.size(); ++i) serialize(data[i], out); } inline void deserialize(gpu_data& item, std::istream& in) { int version; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::gpu_data."); size_t s; deserialize(s, in); item.set_size(s); auto data = item.host(); for (size_t i = 0; i < item.size(); ++i) deserialize(data[i], in); } #ifdef DLIB_USE_CUDA void memcpy (gpu_data& dest, const gpu_data& src); void memcpy ( gpu_data& dest, size_t dest_offset, const gpu_data& src, size_t src_offset, size_t num ); #else inline void memcpy (gpu_data& dest, const gpu_data& src) { DLIB_CASSERT(dest.size() == src.size()); if (src.size() == 0 || &dest == &src) return; std::memcpy(dest.host_write_only(), src.host(), sizeof(float)*src.size()); } inline void memcpy ( gpu_data& dest, size_t dest_offset, const gpu_data& 
src, size_t src_offset, size_t num ) { DLIB_CASSERT(dest_offset + num <= dest.size()); DLIB_CASSERT(src_offset + num <= src.size()); if (num == 0) return; if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num) { // if they perfectly alias each other then there is nothing to do if (dest_offset == src_offset) return; else std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num); } else { // if we write to the entire thing then we can use host_write_only() if (dest_offset == 0 && num == dest.size()) std::memcpy(dest.host_write_only(), src.host()+src_offset, sizeof(float)*num); else std::memcpy(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num); } } #endif // ---------------------------------------------------------------------------------------- } #endif // DLIB_GPU_DaTA_H_ ================================================ FILE: dlib/cuda/gpu_data_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_GPU_DaTA_ABSTRACT_H_ #ifdef DLIB_GPU_DaTA_ABSTRACT_H_ #include "cuda_errors.h" #include "../serialize.h" namespace dlib { // ---------------------------------------------------------------------------------------- class gpu_data { /*! WHAT THIS OBJECT REPRESENTS This object is a block of size() floats, all stored contiguously in memory. Importantly, it keeps two copies of the floats, one on the host CPU side and another on the GPU device side. It automatically performs the necessary host/device transfers to keep these two copies of the data in sync. All transfers to the device happen asynchronously with respect to the default CUDA stream so that CUDA kernel computations can overlap with data transfers. However, any transfers from the device to the host happen synchronously in the default CUDA stream. 
Therefore, you should perform all your CUDA kernel launches on the default stream so that transfers back to the host do not happen before the relevant computations have completed. If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all. Instead, it will simply store one host side memory block of floats. THREAD SAFETY Instances of this object are not thread-safe. So don't touch one from multiple threads at the same time. !*/ public: gpu_data( ); /*! ensures - #size() == 0 - #host() == nullptr - #device() == nullptr - #host_ready() == true - #device_ready() == true - #device_id() == 0 !*/ // This object is not copyable, however, it is movable. gpu_data(const gpu_data&) = delete; gpu_data& operator=(const gpu_data&) = delete; gpu_data(gpu_data&& item); gpu_data& operator=(gpu_data&& item); int device_id( ) const; /*! ensures - returns the ID of the CUDA device that allocated this memory. I.e. the number returned by cudaGetDevice() when the memory was allocated. - If CUDA is not being used then this function always returns 0. !*/ void async_copy_to_device( ); /*! ensures - if (!device_ready()) then - Begins asynchronously copying host data to the device once it is safe to do so. I.e. This function will wait until any previously scheduled CUDA kernels, which are using the device() memory block, have completed before transferring the new data to the device. - A call to device() that happens before the transfer completes will block until the transfer is complete. That is, it is safe to call async_copy_to_device() and then immediately call device(). !*/ void set_size( size_t new_size ); /*! ensures - #size() == new_size - if (new_size != size()) then - the previous contents are discarded and the values in the resized block are undefined. - #host_ready() == true - #device_ready() == true !*/ bool host_ready ( ) const; /*! ensures - returns true if and only if the host's copy of the data is current. The host's data is current if there aren't any modifications to the data which were made on the device side that have yet to be copied to the host. !*/ bool device_ready ( ) const; /*!
ensures - returns true if and only if the device's copy of the data is current. The device's data is current if there aren't any modifications to the data which were made on the host side that have yet to be copied to the device. !*/ const float* host( ) const; /*! ensures - returns a pointer to the host memory block of size() contiguous float values or nullptr if size()==0. - if (!host_ready()) then - copies the data from the device to the host, while this is happening the call to host() blocks. - #host_ready() == true !*/ float* host( ); /*! ensures - returns a pointer to the host memory block of size() contiguous float values or nullptr if size()==0. - if (!host_ready()) then - copies the data from the device to the host, while this is happening the call to host() blocks. - #host_ready() == true - #device_ready() == false I.e. Marks the device side data as out of date so that the next call to device() will perform a host to device transfer. If you want to begin the transfer immediately then you can call async_copy_to_device() after calling host(). !*/ float* host_write_only( ); /*! ensures - This function returns the same pointer as host(), except that it never performs a device to host memory copy. Instead, it immediately marks the device side data as out of date, effectively discarding it. Therefore, the values in the data pointed to by host_write_only() are undefined and you should only call host_write_only() if you are going to assign to every memory location in the returned memory block. - #host_ready() == true - #device_ready() == false !*/ const float* device( ) const; /*! requires - DLIB_USE_CUDA is #defined ensures - returns a pointer to the device memory block of size() contiguous float values or nullptr if size()==0. - if (!device_ready()) then - copies the data from the host to the device, while this is happening the call to device() blocks. - #device_ready() == true !*/ float* device( ); /*!
requires - DLIB_USE_CUDA is #defined ensures - returns a pointer to the device memory block of size() contiguous float values or nullptr if size()==0. - if (!device_ready()) then - copies the data from the host to the device, while this is happening the call to device() blocks. - #host_ready() == false - #device_ready() == true !*/ float* device_write_only( ); /*! requires - DLIB_USE_CUDA is #defined ensures - This function returns the same pointer as device(), except that it never performs a host to device memory copy. Instead, it immediately marks the host side data as out of date, effectively discarding it. Therefore, the values in the data pointed to by device_write_only() are undefined and you should only call device_write_only() if you are going to assign to every memory location in the returned memory block. - #host_ready() == false - #device_ready() == true !*/ size_t size( ) const; /*! ensures - returns the number of floats contained in this object. !*/ void swap ( gpu_data& item ); /*! ensures - swaps the state of *this and item !*/ }; void serialize(const gpu_data& item, std::ostream& out); void deserialize(gpu_data& item, std::istream& in); /*! provides serialization support !*/ void memcpy ( gpu_data& dest, const gpu_data& src ); /*! requires - dest.size() == src.size() ensures - Copies the data in src to dest. If the device data is current (i.e. device_ready()==true) on both src and dest then the copy will happen entirely on the device side. - It doesn't matter what GPU device is selected by cudaSetDevice(). You can always copy gpu_data objects to and from each other regardless. - This function blocks until the copy has completed. !*/ void memcpy ( gpu_data& dest, size_t dest_offset, const gpu_data& src, size_t src_offset, size_t num ); /*!
requires - dest_offset + num <= dest.size() - src_offset + num <= src.size() ensures - Copies the data in src to dest, but only copies data in the range [src.host()+src_offset, src.host()+src_offset+num) to [dest.host()+dest_offset, dest.host()+dest_offset+num). Therefore, it is just like the above memcpy() except that you can specify some subset of data in a gpu_data object to be copied. - dest and src may be the same object, in which case the two ranges are allowed to overlap. - Like the above version of memcpy(), the copy will happen in the most efficient way, automatically using the appropriate type of host/device transfers based on where data is currently resident. - It doesn't matter what GPU device is selected by cudaSetDevice(). You can always copy gpu_data objects to and from each other regardless. - This function blocks until the copy has completed. !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_GPU_DaTA_ABSTRACT_H_ ================================================ FILE: dlib/cuda/operation_mode.h ================================================ // Copyright (C) 2024 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CUDA_OPERATION_MODE_H #define DLIB_CUDA_OPERATION_MODE_H namespace dlib { // ---------------------------------------------------------------------------------------- /*! This enum is used to select the mode of operation for certain functions in dlib (such as gemm and softmax). It specifies whether a computation is performed on the nr() x nc() matrices contained in a tensor, or whether the tensor is instead treated as a num_samples() x k() matrix. This lets those functions organize the tensor computation according to the dimensions they need.
*/ enum class operation_mode { CHANNEL_WISE = 0, PLANE_WISE = 1 }; // ---------------------------------------------------------------------------------------- } // namespace dlib #endif // DLIB_CUDA_OPERATION_MODE_H ================================================ FILE: dlib/cuda/tensor.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DNn_TENSOR_H_ #define DLIB_DNn_TENSOR_H_ #include "tensor_abstract.h" #include #include "../matrix.h" #include "cudnn_dlibapi.h" #include "gpu_data.h" #include "../byte_orderer.h" #include #include "../any.h" /* NOTE(review): two #include directives above have no target and template argument lists are missing below (e.g. "template tensor& operator="); this looks like extraction damage, so regenerate from the repository rather than editing by hand. */ namespace dlib { // ---------------------------------------------------------------------------------------- class tensor; namespace cuda { void set_tensor ( tensor& t, float value ); void scale_tensor ( tensor& t, float value ); } // ---------------------------------------------------------------------------------------- /* Abstract 4D float array (num_samples x k x nr x nc) whose storage is a gpu_data block supplied by subclasses through data(). */ class tensor { public: tensor ( ) : m_n(0), m_k(0), m_nr(0), m_nc(0), m_size(0) { } virtual ~tensor() {} long long num_samples() const { return m_n; } long long k() const { return m_k; } long long nr() const { return m_nr; } long long nc() const { return m_nc; } size_t size() const { return m_size; } typedef float* iterator; typedef const float* const_iterator; iterator begin() { return host(); } const_iterator begin() const { return host(); } iterator end() { return host()+size(); } const_iterator end() const { return host()+size(); } void async_copy_to_device() const { data().async_copy_to_device(); } virtual const float* host() const = 0; virtual float* host() = 0; virtual float* host_write_only() = 0; virtual const float* device() const = 0; virtual float* device() = 0; virtual float* device_write_only() = 0; virtual const any& annotation() const = 0; virtual any& annotation() = 0; int device_id() const { return data().device_id(); } tensor& operator= (float val) {
#ifdef DLIB_USE_CUDA // If you are using CUDA then presumably you will be mostly using tensors on // the GPU. So unless you seem to be actively working with the host side's // data then we do this initialization on the device side since this avoids a // host to device transfer that would likely immediately follow. if (data().device_ready()) { cuda::set_tensor(*this, val); return *this; } #endif auto d = host_write_only(); for (size_t i = 0; i < size(); ++i) d[i] = val; return *this; } /* Multiplies every element by val; when CUDA is enabled this always runs on the device. */ tensor& operator*= (float val) { #ifdef DLIB_USE_CUDA cuda::scale_tensor(*this, val); return *this; #else for (auto& d : *this) d *= val; return *this; #endif } /* Implemented as multiplication by the reciprocal 1.f/val. */ tensor& operator/= (float val) { *this *= 1.f/val; return *this; } /* The matrix must be num_samples() x (k()*nr()*nc()), i.e. one row per sample. */ template tensor& operator= (const matrix_exp& item) { DLIB_CASSERT(num_samples() == item.nr() && nr()*nc()*k() == item.nc()); static_assert((is_same_type::value == true), "To assign a matrix to a tensor the matrix must contain float values"); set_ptrm(host_write_only(), m_n, m_nr*m_nc*m_k) = item; return *this; } template tensor& operator+= (const matrix_exp& item) { DLIB_CASSERT(num_samples() == item.nr() && nr()*nc()*k() == item.nc()); static_assert((is_same_type::value == true), "To assign a matrix to a tensor the matrix must contain float values"); set_ptrm(host(), m_n, m_nr*m_nc*m_k) += item; return *this; } template tensor& operator-= (const matrix_exp& item) { DLIB_CASSERT(num_samples() == item.nr() && nr()*nc()*k() == item.nc()); static_assert((is_same_type::value == true), "To assign a matrix to a tensor the matrix must contain float values"); set_ptrm(host(), m_n, m_nr*m_nc*m_k) -= item; return *this; } template void set_sample ( unsigned long long idx, const matrix_exp& item ) { DLIB_CASSERT(idx < (unsigned long long)num_samples()); DLIB_CASSERT(item.size() == nr()*nc()*k()); static_assert((is_same_type::value == true), "To assign a matrix to a tensor the matrix must contain float values"); set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) =
item; } template void add_to_sample ( unsigned long long idx, const matrix_exp& item ) { DLIB_CASSERT(idx < (unsigned long long)num_samples()); DLIB_CASSERT(item.size() == nr()*nc()*k()); static_assert((is_same_type::value == true), "To assign a matrix to a tensor the matrix must contain float values"); set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item; } #ifdef DLIB_USE_CUDA virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor ( ) const = 0; #endif /* Copies src's elements into dest. Either argument may be an alias_tensor view, so each side's alias offset into its gpu_data is applied. */ friend void memcpy ( tensor& dest, const tensor& src ) { DLIB_CASSERT(dest.size() == src.size()); memcpy(dest.data(), dest.get_alias_offset(), src.data(), src.get_alias_offset(), src.size()); } protected: friend class alias_tensor; virtual gpu_data& data() = 0; virtual const gpu_data& data() const = 0; virtual size_t get_alias_offset() const { return 0; } // needed by alias_tensor. long long m_n; long long m_k; long long m_nr; long long m_nc; long long m_size; // always equal to m_n*m_k*m_nr*m_nc }; // ---------------------------------------------------------------------------------------- /* Returns true if size() equals one of num_samples(), k(), nr() or nc(); for a non-empty tensor this means every other dimension is 1. */ inline bool is_vector ( const tensor& t ) { return t.size() == (size_t)t.num_samples() || t.size() == (size_t)t.k() || t.size() == (size_t)t.nr() || t.size() == (size_t)t.nc(); } // ---------------------------------------------------------------------------------------- /* Views the tensor's host data as an nr x nc matrix without copying; nr*nc must equal t.size(). */ inline const matrix_op > mat ( const tensor& t, long long nr, long long nc ) { DLIB_ASSERT(nr >= 0 && nc >= 0 , "\tconst matrix_exp mat(tensor, nr, nc)" << "\n\t nr and nc must be >= 0" << "\n\t nr: " << nr << "\n\t nc: " << nc ); DLIB_ASSERT(nr*nc == (long long)t.size() , "\tconst matrix_exp mat(tensor, nr, nc)" << "\n\t The sizes don't match up."
<< "\n\t nr*nc: " << nr*nc << "\n\t t.size(): " << t.size() ); typedef op_pointer_to_mat op; return matrix_op(op(t.host(),nr,nc)); } /* Views the tensor as a num_samples() x (size()/num_samples()) matrix, one row per sample; an empty tensor yields a 0 x 0 matrix. */ inline const matrix_op > mat ( const tensor& t ) { if (t.size() != 0) return mat(t, t.num_samples(), t.size()/t.num_samples()); else return mat((float*)0,0,0); } /* Views channel k of the given sample as an nr() x nc() matrix over the host data. */ inline const matrix_op > image_plane ( const tensor& t, long long sample = 0, long long k = 0 ) { DLIB_ASSERT(0 <= sample && sample < t.num_samples() && 0 <= k && k < t.k() && t.size() != 0, "\tconst matrix_exp image_plane(tensor,sample,k)" << "\n\t Invalid arguments were given to this function." << "\n\t sample: " << sample << "\n\t k: " << k << "\n\t t.num_samples(): " << t.num_samples() << "\n\t t.k(): " << t.k() << "\n\t t.size(): " << t.size() ); typedef op_pointer_to_mat op; return matrix_op(op(t.host() + ((sample*t.k() + k)*t.nr())*t.nc(), t.nr(), t.nc())); } // ---------------------------------------------------------------------------------------- inline bool have_same_dimensions ( const tensor& a, const tensor& b ) { return a.num_samples() == b.num_samples() && a.k() == b.k() && a.nr() == b.nr() && a.nc() == b.nc(); } // ---------------------------------------------------------------------------------------- /* A tensor that owns its gpu_data storage and can be resized with set_size(). */ class resizable_tensor : public tensor { public: resizable_tensor( ) {} template resizable_tensor( const matrix_exp& item ) { set_size(item.nr(), item.nc()); *this = item; } explicit resizable_tensor( long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 ) { DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0); set_size(n_,k_,nr_,nc_); } resizable_tensor(const resizable_tensor& item) : _annotation(item.annotation()) { copy_size(item); memcpy(*this, item); } resizable_tensor(const tensor& item) : _annotation(item.annotation()) { copy_size(item); memcpy(*this, item); } resizable_tensor(resizable_tensor&& item) { swap(item); } resizable_tensor& operator=(resizable_tensor&& item) { swap(item); return *this; } virtual const float* host() const {
return data_instance.host(); } virtual float* host() { return data_instance.host(); } virtual float* host_write_only() { return data_instance.host_write_only(); } virtual const float* device() const { return data_instance.device(); } virtual float* device() { return data_instance.device(); } virtual float* device_write_only() { return data_instance.device_write_only(); } virtual const any& annotation() const { return _annotation; } virtual any& annotation() { return _annotation; } void clear( ) { set_size(0,0,0,0); _annotation.clear(); // free underlying memory data_instance.set_size(0); } void copy_size ( const tensor& item ) { set_size(item.num_samples(), item.k(), item.nr(), item.nc()); } resizable_tensor& operator= (float val) { tensor::operator=(val); return *this; } template resizable_tensor& operator= ( const matrix_exp& item ) { if (!(num_samples() == item.nr() && k()*nr()*nc() == item.nc())) set_size(item.nr(), item.nc()); tensor::operator=(item); return *this; } /* Sets the dimensions. The underlying gpu_data is only reallocated when it must grow, which discards its contents; shrinking keeps the larger block, which only clear() releases. */ void set_size( long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 ) { DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0); m_n = n_; m_k = k_; m_nr = nr_; m_nc = nc_; m_size = n_*k_*nr_*nc_; if ((long long)data_instance.size() < m_size) data_instance.set_size(m_size); #ifdef DLIB_USE_CUDA cudnn_descriptor.set_size(m_n,m_k,m_nr,m_nc); #endif } /* Copy assignment via copy-and-swap. */ resizable_tensor& operator= (const resizable_tensor& item) { resizable_tensor temp(item); temp.swap(*this); return *this; } resizable_tensor& operator= (const tensor& item) { resizable_tensor temp(item); temp.swap(*this); return *this; } void swap(resizable_tensor& item) { std::swap(m_n, item.m_n); std::swap(m_k, item.m_k); std::swap(m_nr, item.m_nr); std::swap(m_nc, item.m_nc); std::swap(m_size, item.m_size); std::swap(data_instance, item.data_instance); std::swap(_annotation, item._annotation); #ifdef DLIB_USE_CUDA std::swap(cudnn_descriptor, item.cudnn_descriptor); #endif } #ifdef DLIB_USE_CUDA virtual const cuda::tensor_descriptor&
get_cudnn_tensor_descriptor ( ) const { return cudnn_descriptor; } #endif private: #ifdef DLIB_USE_CUDA cuda::tensor_descriptor cudnn_descriptor; #endif gpu_data data_instance; any _annotation; virtual gpu_data& data() { return data_instance; } virtual const gpu_data& data() const { return data_instance; } }; inline void serialize(const tensor& item, std::ostream& out) { int version = 2; serialize(version, out); serialize(item.num_samples(), out); serialize(item.k(), out); serialize(item.nr(), out); serialize(item.nc(), out); /* NOTE(review): the return value of sputn() below is not checked, so a short write is not reported here, unlike the read path in deserialize(resizable_tensor&). */ byte_orderer bo; auto sbuf = out.rdbuf(); for (auto d : item) { // Write out our data as 4byte little endian IEEE floats rather than using // dlib's default float serialization. We do this because it will result in // more compact outputs. It's slightly less portable but it seems doubtful // that any CUDA enabled platform isn't going to use IEEE floats. But if one // does we can just update the serialization code here to handle it if such a // platform is encountered.
bo.host_to_little(d); static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats"); sbuf->sputn((char*)&d, sizeof(d)); } } inline void deserialize(resizable_tensor& item, std::istream& in) { int version; deserialize(version, in); if (version != 2) throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor."); long long num_samples=0, k=0, nr=0, nc=0; deserialize(num_samples, in); deserialize(k, in); deserialize(nr, in); deserialize(nc, in); item.set_size(num_samples, k, nr, nc); byte_orderer bo; auto sbuf = in.rdbuf(); for (auto& d : item) { static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats"); if (sbuf->sgetn((char*)&d,sizeof(d)) != sizeof(d)) { in.setstate(std::ios::badbit); throw serialization_error("Error reading data while deserializing dlib::resizable_tensor."); } bo.little_to_host(d); } } // ---------------------------------------------------------------------------------------- /* Dot product of two equally sized tensors, computed on the host with a double accumulator. */ inline double dot( const tensor& a, const tensor& b ) { DLIB_CASSERT(a.size() == b.size()); const float* da = a.host(); const float* db = b.host(); double sum = 0; for (size_t i = 0; i < a.size(); ++i) sum += da[i]*db[i]; return sum; } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- /* A non-owning view of part of another tensor's gpu_data, starting data_offset floats into it; only alias_tensor can create one. */ class alias_tensor_instance : public tensor { alias_tensor_instance( ) : data_instance(0), _annotation(0), data_offset(0) {} public: friend class alias_tensor; friend class alias_tensor_const_instance; alias_tensor_instance& operator= (float val) { tensor::operator=(val); return *this; } template alias_tensor_instance& operator= (const matrix_exp& item) { tensor::operator=(item); return *this; } virtual const float* host() const { return data_instance->host()+data_offset; } virtual float* host() { return data_instance->host()+data_offset; } virtual float*
host_write_only() { return data_instance->host()+data_offset; } virtual const float* device() const { return data_instance->device()+data_offset; } virtual float* device() { return data_instance->device()+data_offset; } /* NOTE: the *_write_only() accessors of an alias delegate to host()/device() of the shared gpu_data, so they may still trigger a transfer; presumably because the view covers only part of the buffer and must not mark the rest as discarded. */ virtual float* device_write_only() { return data_instance->device()+data_offset; } virtual const any& annotation() const { return *_annotation; } virtual any& annotation() { return *_annotation; } #ifdef DLIB_USE_CUDA virtual const cuda::tensor_descriptor& get_cudnn_tensor_descriptor ( ) const { return *cudnn_descriptor; } #endif private: virtual size_t get_alias_offset() const { return data_offset; } #ifdef DLIB_USE_CUDA std::shared_ptr cudnn_descriptor; #endif gpu_data* data_instance; any* _annotation; size_t data_offset; virtual gpu_data& data() { return *data_instance; } virtual const gpu_data& data() const { return *data_instance; } }; // ---------------------------------------------------------------------------------------- /* Wraps an alias_tensor_instance so that only a const tensor& to it is exposed. */ class alias_tensor_const_instance { public: const tensor& get() const { return inst; } operator const tensor& () { return inst; } alias_tensor_const_instance(const alias_tensor_instance& item) : inst(item) {} private: alias_tensor_instance inst; friend class alias_tensor; alias_tensor_const_instance() {} }; // ---------------------------------------------------------------------------------------- /* Describes the shape of a view; operator() binds that shape to a tensor at a given float offset. */ class alias_tensor { public: alias_tensor ( ) {} alias_tensor ( long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 ) { DLIB_ASSERT( n_ >= 0 && k_ >= 0 && nr_ >= 0 && nc_ >= 0); inst.m_n = n_; inst.m_k = k_; inst.m_nr = nr_; inst.m_nc = nc_; inst.m_size = n_*k_*nr_*nc_; } long long num_samples( ) const { return inst.m_n; } long long k( ) const { return inst.m_k; } long long nr( ) const { return inst.m_nr; } long long nc( ) const { return inst.m_nc; } size_t size( ) const { return inst.m_size; } alias_tensor_instance operator() ( tensor& t, size_t offset = 0 ) const { DLIB_CASSERT(offset+size() <= t.size(),
/* NOTE(review): text is missing between "offset: "< and (); below -- the unmatched '}' and '#endif' that follow show part of this function was lost during extraction; restore it from the repository. */ "offset: "<(); inst.cudnn_descriptor->set_size(inst.m_n, inst.m_k, inst.m_nr, inst.m_nc); } #endif inst.data_instance = &t.data(); inst._annotation = &t.annotation(); // Note that t might already be an aliasing tensor so we need to take that into // account. inst.data_offset = t.get_alias_offset()+offset; return inst; } alias_tensor_const_instance operator() ( const tensor& t, size_t offset = 0 ) const { alias_tensor_const_instance temp; temp.inst = (*this)(const_cast(t),offset); return temp; } private: mutable alias_tensor_instance inst; }; inline void serialize(const alias_tensor& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.num_samples(), out); serialize(item.k(), out); serialize(item.nr(), out); serialize(item.nc(), out); } inline void deserialize(alias_tensor& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::alias_tensor."); long long num_samples, k, nr, nc; deserialize(num_samples, in); deserialize(k, in); deserialize(nr, in); deserialize(nc, in); item = alias_tensor(num_samples, k, nr, nc); } // ---------------------------------------------------------------------------------------- /* Overload that lets a temporary alias (e.g. the result of alias_tensor::operator()) be used directly as the destination. */ inline void memcpy ( alias_tensor_instance&& dest, const tensor& src ) { memcpy(static_cast(dest), src); } } #endif // DLIB_DNn_TENSOR_H_ ================================================ FILE: dlib/cuda/tensor_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_DNn_TENSOR_ABSTRACT_H_ #ifdef DLIB_DNn_TENSOR_ABSTRACT_H_ #include "../matrix.h" #include "../any/any_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- class tensor { /*!
WHAT THIS OBJECT REPRESENTS This object represents a 4D array of float values, all stored contiguously in memory. Importantly, it keeps two copies of the floats, one on the host CPU side and another on the GPU device side. It automatically performs the necessary host/device transfers to keep these two copies of the data in sync. All transfers to the device happen asynchronously with respect to the default CUDA stream so that CUDA kernel computations can overlap with data transfers. However, any transfers from the device to the host happen synchronously in the default CUDA stream. Therefore, you should perform all your CUDA kernel launches on the default stream so that transfers back to the host do not happen before the relevant computations have completed. If DLIB_USE_CUDA is not #defined then this object will not use CUDA at all. Instead, it will simply store one host side memory block of floats. Finally, the convention in dlib code is to interpret the tensor as a set of num_samples() 3D arrays, each of dimension k() by nr() by nc(). Also, while this class does not specify a memory layout, the convention is to assume that indexing into an element at coordinates (sample,k,r,c) can be accomplished via: host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c] THREAD SAFETY Instances of this object are not thread-safe. So don't touch one from multiple threads at the same time. !*/ public: virtual ~tensor(); long long num_samples( ) const; /*! ensures - returns the number of 3D arrays of dimension k() by nr() by nc() there are in this object. !*/ long long k( ) const; /*! ensures - returns the k dimension of this tensor. Generally, we think of a tensor as containing num_samples() images of nr() by nc() rows and columns, each with k() channels. !*/ long long nr( ) const; /*! ensures - returns the number of rows in this tensor. !*/ long long nc( ) const; /*! ensures - returns the number of columns in this tensor. !*/ size_t size( ) const; /*! 
ensures - returns num_samples()*k()*nr()*nc() (i.e. the total number of floats in this tensor) !*/ void async_copy_to_device( ) const; /*! ensures - This function does not block. - if (the host version of the data is newer than the device's copy) then - Begins asynchronously copying host data to the device. - A call to device() that happens before the transfer completes will block until the transfer is complete. That is, it is safe to call async_copy_to_device() and then immediately call device(). !*/ typedef float* iterator; typedef const float* const_iterator; iterator begin() { return host(); } const_iterator begin() const { return host(); } iterator end() { return host()+size(); } const_iterator end() const { return host()+size(); } /*! ensures - makes a tensor iterable just like the STL containers. !*/ virtual const float* host( ) const = 0; /*! ensures - returns a pointer to the host memory block of size() contiguous float values or nullptr if size()==0. - if (the host's copy of the data is out of date) then - copies the data from the device to the host, while this is happening the call to host() blocks. !*/ virtual float* host( ) = 0; /*! ensures - returns a pointer to the host memory block of size() contiguous float values or nullptr if size()==0. - if (the host's copy of the data is out of date) then - copies the data from the device to the host, while this is happening the call to host() blocks. - Marks the device side data as out of date so that the next call to device() will perform a host to device transfer. If you want to begin the transfer immediately then you can call async_copy_to_device() after calling host(). !*/ virtual float* host_write_only( ) = 0; /*! ensures - This function returns the same pointer as host(), except that it never performs a device to host memory copy. Instead, it immediately marks the device side data as out of date, effectively discarding it. 
Therefore, the values in the data pointed to by host_write_only() are undefined and you should only call host_write_only() if you are going to assign to every memory location in the returned memory block. !*/ virtual const float* device( ) const = 0; /*! requires - DLIB_USE_CUDA is #defined ensures - returns a pointer to the device memory block of size() contiguous float values or nullptr if size()==0. - if (the device's copy of the data is out of date) then - copies the data from the host to the device, while this is happening the call to device() blocks. !*/ virtual float* device( ) = 0; /*! requires - DLIB_USE_CUDA is #defined ensures - returns a pointer to the device memory block of size() contiguous float values or nullptr if size()==0. - if (the device's copy of the data is out of date) then - copies the data from the host to the device, while this is happening the call to device() blocks. - Marks the host side data as out of date so that the next call to host() will perform a device to host transfer. !*/ virtual float* device_write_only( ) = 0; /*! requires - DLIB_USE_CUDA is #defined ensures - This function returns the same pointer as device(), except that it never performs a host to device memory copy. Instead, it immediately marks the host side data as out of date, effectively discarding it. Therefore, the values in the data pointed to by device_write_only() are undefined and you should only call device_write_only() if you are going to assign to every memory location in the returned memory block. !*/ virtual const any& annotation( ) const = 0; /*! ensures - returns a const reference to the any object in this tensor. The any object can be used to store any additional annotation you like in a tensor. However, it should be noted that the annotation() is ignored by serialize() and therefore not saved when a tensor is serialized. !*/ virtual any& annotation( ) = 0; /*! ensures - returns a non-const reference to the any object in this tensor. 
The any object can be used to store any additional annotation you like in a tensor. However, it should be noted that the annotation() is ignored by serialize() and therefore not saved when a tensor is serialized. !*/ int device_id( ) const; /*! ensures - returns the ID of the CUDA device that allocated this memory. I.e. the number returned by cudaGetDevice() when the memory was allocated. - If CUDA is not being used then this function always returns 0. !*/ tensor& operator= ( float val ); /*! ensures - sets all elements of this tensor equal to val. - returns *this !*/ tensor& operator*= ( float val ); /*! ensures - pointwise multiplies all elements of *this tensor with val. - returns *this !*/ tensor& operator/= ( float val ); /*! ensures - pointwise divides all elements of *this tensor with val. - returns *this !*/ template tensor& operator= ( const matrix_exp& item ); /*! requires - num_samples() == item.nr() - k()*nr()*nc() == item.nc() - item contains float values ensures - Assigns item to *this tensor by performing: set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; !*/ template tensor& operator+= ( const matrix_exp& item ); /*! requires - num_samples() == item.nr() - k()*nr()*nc() == item.nc() - item contains float values ensures - Adds item to *this tensor by performing: set_ptrm(host(), num_samples(), k()*nr()*nc()) += item; !*/ template tensor& operator-= ( const matrix_exp& item ); /*! requires - num_samples() == item.nr() - k()*nr()*nc() == item.nc() - item contains float values ensures - Subtracts item from *this tensor by performing: set_ptrm(host(), num_samples(), k()*nr()*nc()) -= item; !*/ template void set_sample ( unsigned long long idx, const matrix_exp& item ); /*! 
requires - idx < num_samples() - k()*nr()*nc() == item.size() - item contains float values ensures - Assigns item to the idx'th sample in *this by performing: set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) = item; !*/ template void add_to_sample ( unsigned long long idx, const matrix_exp& item ); /*! requires - idx < num_samples() - k()*nr()*nc() == item.size() - item contains float values ensures - Adds item to the idx'th sample in *this by performing: set_ptrm(host()+idx*item.size(), item.nr(), item.nc()) += item; !*/ protected: // You can't move or copy another tensor into *this since that might modify the // tensor's dimensions. If you want to do that sort of thing then use a // resizable_tensor. tensor(const tensor& item); tensor& operator= (const tensor& item); tensor(tensor&& item); tensor& operator=(tensor&& item); }; // ---------------------------------------------------------------------------------------- void memcpy ( tensor& dest, const tensor& src ); /*! requires - dest.size() == src.size() ensures - Copies the data in src to dest. If the device data is current on both src and dest then the copy will happen entirely on the device side. - It doesn't matter what GPU device is selected by cudaSetDevice(). You can always copy tensor objects to and from each other regardless. - This function blocks until the copy has completed. !*/ // ---------------------------------------------------------------------------------------- bool is_vector ( const tensor& t ); /*! ensures - returns true if and only if one of the following is true: - t.size() == t.num_samples() - t.size() == t.k() - t.size() == t.nr() - t.size() == t.nc() !*/ // ---------------------------------------------------------------------------------------- const matrix_exp mat ( const tensor& t, long long nr, long long nc ); /*! 
requires - nr >= 0 - nc >= 0 - nr*nc == t.size() ensures - returns a matrix M such that: - M.nr() == nr - M.nc() == nc - for all valid r and c: M(r,c) == t.host()[r*nc + c] (i.e. the tensor is interpreted as a matrix laid out in memory in row major order) !*/ const matrix_exp mat ( const tensor& t ); /*! ensures - if (t.size() != 0) then - returns mat(t, t.num_samples(), t.size()/t.num_samples()) - else - returns an empty matrix. !*/ const matrix_exp image_plane ( const tensor& t, long long sample = 0, long long k = 0 ); /*! requires - t.size() != 0 - 0 <= sample < t.num_samples() - 0 <= k < t.k() ensures - returns the k-th image plane from the sample-th image in t. That is, returns a matrix M such that: - M contains float valued elements. - M.nr() == t.nr() - M.nc() == t.nc() - for all valid r and c: - M(r,c) == t.host()[((sample*t.k() + k)*t.nr() + r)*t.nc() + c] !*/ // ---------------------------------------------------------------------------------------- bool have_same_dimensions ( const tensor& a, const tensor& b ); /*! ensures - returns true if and only if all of the following are satisfied: - a.num_samples() == b.num_samples() - a.k() == b.k() - a.nr() == b.nr() - a.nc() == b.nc() !*/ // ---------------------------------------------------------------------------------------- class resizable_tensor : public tensor { /*! WHAT THIS OBJECT REPRESENTS This object is just a tensor with the additional ability to be resized. !*/ public: resizable_tensor( ); /*! ensures - #size() == 0 - #num_samples() == 0 - #k() == 0 - #nr() == 0 - #nc() == 0 - #capacity() == 0 !*/ template resizable_tensor( const matrix_exp& item ); /*!
requires - item contains float values ensures - #num_samples() == item.nr() - #k() == item.nc() - #nr() == 1 - #nc() == 1 - Assigns item to *this tensor by performing: set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; - #capacity() == size() !*/ explicit resizable_tensor( long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 ); /*! requires - n_ >= 0 - k_ >= 0 - nr_ >= 0 - nc_ >= 0 ensures - #size() == n_*k_*nr_*nc_ - #num_samples() == n_ - #k() == k_ - #nr() == nr_ - #nc() == nc_ - #capacity() == size() !*/ // This object is copyable and movable resizable_tensor(const resizable_tensor&) = default; resizable_tensor(resizable_tensor&&) = default; resizable_tensor& operator= (const resizable_tensor&) = default; resizable_tensor& operator= (resizable_tensor&&) = default; size_t capacity ( ) const; /*! ensures - returns the total number of floats allocated. This might be different from the size() since calls to set_size() that make a tensor smaller don't trigger reallocations. They simply adjust the nominal dimensions while keeping the same allocated memory block. This makes calls to set_size() very fast. If you need to deallocate a tensor then use clear(). !*/ void clear( ); /*! ensures - #size() == 0 - #num_samples() == 0 - #k() == 0 - #nr() == 0 - #nc() == 0 - #annotation().is_empty() == true - #capacity() == 0 !*/ void copy_size ( const tensor& item ); /*! ensures - resizes *this so that: have_same_dimensions(#*this, item)==true !*/ void set_size( long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 ); /*! requires - n_ >= 0 - k_ >= 0 - nr_ >= 0 - nc_ >= 0 ensures - #size() == n_*k_*nr_*nc_ - #num_samples() == n_ - #k() == k_ - #nr() == nr_ - #nc() == nc_ - #capacity() == max(#size(), capacity()) (i.e. capacity() never goes down when calling set_size().) !*/ template resizable_tensor& operator= ( const matrix_exp& item ); /*! 
requires - item contains float values ensures - if (num_samples() == item.nr() && k()*nr()*nc() == item.nc()) then - the dimensions of this tensor are not changed - else - #num_samples() == item.nr() - #k() == item.nc() - #nr() == 1 - #nc() == 1 - Assigns item to *this tensor by performing: set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; !*/ }; void serialize(const tensor& item, std::ostream& out); void deserialize(resizable_tensor& item, std::istream& in); /*! provides serialization support for tensor and resizable_tensor. Note that you can serialize to/from any combination of tensor and resizable_tensor objects. !*/ // ---------------------------------------------------------------------------------------- double dot( const tensor& a, const tensor& b ); /*! requires - a.size() == b.size() ensures - returns the dot product between a and b when they are both treated as a.size() dimensional vectors. That is, this function pointwise multiplies the vectors together, then sums the result and returns it. !*/ // ---------------------------------------------------------------------------------------- class alias_tensor_instance : public tensor { /*! WHAT THIS OBJECT REPRESENTS This object is a tensor that aliases another tensor. That is, it doesn't have its own block of memory but instead simply holds pointers to the memory of another tensor object. It therefore allows you to efficiently break a tensor into pieces and pass those pieces into functions. An alias_tensor_instance doesn't own the resources it points to in any sense. So it is important to make sure that the underlying owning tensor doesn't get destructed before any alias tensors which point to it are destructed. !*/ // You can't default initialize this object. You can only get instances of it from // alias_tensor::operator(). alias_tensor_instance( ); }; inline void memcpy ( alias_tensor_instance&& dest, const tensor& src ) { memcpy(static_cast(dest), src); } /*!
A convenient overload for copying from src to dest when you have a temporary alias tensor. !*/ class alias_tensor_const_instance { /*! WHAT THIS OBJECT REPRESENTS This is essentially a const version of alias_tensor_instance and therefore represents a tensor. However, due to the mechanics of C++, this object can't inherit from tensor. So instead it provides a get() and an implicit conversion to const tensor. !*/ public: // non-const alias tensors are convertible to const ones. alias_tensor_const_instance(const alias_tensor_instance& item); // Methods that cast the alias to a tensor. const tensor& get() const; operator const tensor& (); private: // You can't default initialize this object. You can only get instances of it from // alias_tensor::operator(). alias_tensor_const_instance(); }; class alias_tensor { /*! WHAT THIS OBJECT REPRESENTS This is a tool for creating tensor objects that alias other tensor objects. That is, it allows you to make a tensor that references the memory space of another tensor object rather than owning its own memory. This allows you to do things like interpret a single tensor in different ways or even as a group of multiple tensors. !*/ public: alias_tensor ( ); /*! ensures - #size() == 0 - #num_samples() == 0 - #k() == 0 - #nr() == 0 - #nc() == 0 !*/ alias_tensor ( long long n_, long long k_ = 1, long long nr_ = 1, long long nc_ = 1 ); /*! requires - n_ >= 0 - k_ >= 0 - nr_ >= 0 - nc_ >= 0 ensures - #size() == n_*k_*nr_*nc_ - #num_samples() == n_ - #k() == k_ - #nr() == nr_ - #nc() == nc_ !*/ long long num_samples() const; long long k() const; long long nr() const; long long nc() const; size_t size() const; alias_tensor_instance operator() ( tensor& t, size_t offset = 0 ) const; /*! requires - offset+size() <= t.size() ensures - Returns a tensor that simply aliases the elements of t beginning with t's offset'th element. 
Specifically, this function returns an aliasing tensor T such that: - T.size() == size() - T.num_samples() == num_samples() - T.k() == k() - T.nr() == nr() - T.nc() == nc() - T.host() == t.host()+offset - T.device() == t.device()+offset - &T.annotation() == &t.annotation() !*/ alias_tensor_const_instance operator() ( const tensor& t, size_t offset = 0 ) const; /*! requires - offset+size() <= t.size() ensures - This function is identical to the above version of operator() except that it takes and returns const tensors instead of non-const tensors. !*/ }; void serialize(const alias_tensor& item, std::ostream& out); void deserialize(alias_tensor& item, std::istream& in); /*! provides serialization support for alias_tensor. !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_TENSOR_ABSTRACT_H_ ================================================ FILE: dlib/cuda/tensor_tools.cpp ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_TeNSOR_TOOLS_CPP_ #define DLIB_TeNSOR_TOOLS_CPP_ #include "tensor_tools.h" #include "../string.h" #include namespace dlib { namespace { std::atomic& dnn_prefer_fastest_algo ( ) { static std::atomic var(true); return var; } } bool dnn_prefer_fastest_algorithms ( ) { return dnn_prefer_fastest_algo(); } void set_dnn_prefer_fastest_algorithms( ) { dnn_prefer_fastest_algo() = true; } void set_dnn_prefer_smallest_algorithms( ) { dnn_prefer_fastest_algo() = false; } } namespace dlib { namespace tt { // ---------------------------------------------------------------------------------------- void inverse_norms ( resizable_tensor& invnorms, const tensor& data, const double eps ) { #ifdef DLIB_USE_CUDA cuda::inverse_norms(invnorms, data, eps); #else invnorms = reciprocal(sqrt(sum_cols(squared(mat(data))) + eps)); #endif } void dot_prods ( resizable_tensor& out, const tensor& lhs, const tensor& rhs ) { #ifdef DLIB_USE_CUDA cuda::dot_prods(out, lhs, rhs); #else out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); #endif } void dot_prods ( bool add_to, tensor& out, const tensor& lhs, const tensor& rhs ) { #ifdef DLIB_USE_CUDA cuda::dot_prods(add_to, out, lhs, rhs); #else if (add_to) out += sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); else out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); #endif } void scale_columns ( tensor& out, const tensor& m, const tensor& v ) { DLIB_CASSERT(have_same_dimensions(out,m)); DLIB_CASSERT(is_vector(v)); if (m.size() == 0 && v.size() == 0) return; DLIB_CASSERT(m.size() != 0); DLIB_CASSERT(m.size()/m.num_samples() == v.size()); #ifdef DLIB_USE_CUDA cuda::scale_columns(out, m, v); #else out = scale_columns(mat(m), mat(v)); #endif } void scale_rows ( tensor& out, const tensor& m, const tensor& v ) { DLIB_CASSERT(have_same_dimensions(out,m)); DLIB_CASSERT(is_vector(v)); if (m.size() == 0 && v.size() == 0) return; DLIB_CASSERT(m.size() != 0); DLIB_CASSERT(m.num_samples() == static_cast(v.size())); #ifdef DLIB_USE_CUDA 
cuda::scale_rows(out, m, v); #else out = scale_rows(mat(m), mat(v)); #endif } void scale_rows2 ( float beta, tensor& out, const tensor& m1, const tensor& m2, const tensor& v1, const tensor& v2 ) { DLIB_CASSERT(have_same_dimensions(out,m1)); DLIB_CASSERT(have_same_dimensions(out,m2)); DLIB_CASSERT(have_same_dimensions(v1,v2)); DLIB_CASSERT(is_vector(mat(v1))); DLIB_CASSERT(static_cast(v1.size()) == m1.num_samples()); #ifdef DLIB_USE_CUDA cuda::scale_rows2(beta, out, m1, m2, v1, v2); #else if (beta == 0) out = scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)); else out = beta*mat(out) + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)); #endif } // ---------------------------------------------------------------------------------------- void exp ( tensor& dest, const tensor& src ) { DLIB_CASSERT(dest.size() == src.size()); #ifdef DLIB_USE_CUDA cuda::exp(dest,src); #else dest = exp(mat(src)); #endif } // ---------------------------------------------------------------------------------------- void log ( tensor& dest, const tensor& src ) { DLIB_CASSERT(dest.size() == src.size()); #ifdef DLIB_USE_CUDA cuda::log(dest,src); #else dest = log(mat(src)); #endif } // ---------------------------------------------------------------------------------------- void log10 ( tensor& dest, const tensor& src ) { DLIB_CASSERT(dest.size() == src.size()); #ifdef DLIB_USE_CUDA cuda::log10(dest,src); #else dest = log10(mat(src)); #endif } // ---------------------------------------------------------------------------------------- void gemm ( float beta, tensor& dest, float alpha, const tensor& lhs, bool trans_lhs, const tensor& rhs, bool trans_rhs, operation_mode mode ) { #ifdef DLIB_USE_CUDA cuda::gemm(beta, dest, alpha, lhs, trans_lhs, rhs, trans_rhs, mode); #else if (mode == operation_mode::CHANNEL_WISE) { if (beta != 0) { if (trans_lhs && trans_rhs) dest = alpha * trans(mat(lhs)) * trans(mat(rhs)) + beta * mat(dest); else if (!trans_lhs && trans_rhs) dest = alpha * 
mat(lhs) * trans(mat(rhs)) + beta * mat(dest); else if (trans_lhs && !trans_rhs) dest = alpha * trans(mat(lhs)) * mat(rhs) + beta * mat(dest); else dest = alpha * mat(lhs) * mat(rhs) + beta * mat(dest); } else { if (trans_lhs && trans_rhs) dest = alpha * trans(mat(lhs)) * trans(mat(rhs)); else if (!trans_lhs && trans_rhs) dest = alpha * mat(lhs) * trans(mat(rhs)); else if (trans_lhs && !trans_rhs) dest = alpha * trans(mat(lhs)) * mat(rhs); else dest = alpha * mat(lhs) * mat(rhs); } } else if (mode == operation_mode::PLANE_WISE) { auto is_matrix = [](const auto& tensor) { return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) || (tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1)); }; long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() }); long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() }); const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest); if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) { num_samples = num_channels = 1; } long lhs_rows = (lhs_is_matrix && lhs.num_samples() > 1) ? lhs.num_samples() : lhs.nr(); long lhs_cols = (lhs_is_matrix && lhs.k() > 1) ? lhs.k() : lhs.nc(); long rhs_rows = (rhs_is_matrix && rhs.num_samples() > 1) ? rhs.num_samples() : rhs.nr(); long rhs_cols = (rhs_is_matrix && rhs.k() > 1) ? rhs.k() : rhs.nc(); long dest_rows = (dest_is_matrix && dest.num_samples() > 1) ? dest.num_samples() : dest.nr(); long dest_cols = (dest_is_matrix && dest.k() > 1) ? dest.k() : dest.nc(); const size_t lhs_plane_size = lhs_rows * lhs_cols; const size_t rhs_plane_size = rhs_rows * rhs_cols; const size_t dest_plane_size = dest_rows * dest_cols; for (long b = 0; b < num_samples; ++b) { for (long c = 0; c < num_channels; ++c) { auto lhs_slice = lhs_is_matrix ? 
alias_tensor(lhs_rows, lhs_cols)(lhs, 0) : alias_tensor(lhs_rows, lhs_cols)(lhs, (b * num_channels + c) * lhs_plane_size); auto rhs_slice = rhs_is_matrix ? alias_tensor(rhs_rows, rhs_cols)(rhs, 0) : alias_tensor(rhs_rows, rhs_cols)(rhs, (b * num_channels + c) * rhs_plane_size); auto dest_slice = dest_is_matrix ? alias_tensor(dest_rows, dest_cols)(dest, 0) : alias_tensor(dest_rows, dest_cols)(dest, (b * num_channels + c) * dest_plane_size); if (beta != 0) { if (trans_lhs && trans_rhs) dest_slice = alpha * trans(mat(lhs_slice)) * trans(mat(rhs_slice)) + beta * mat(dest_slice); else if (!trans_lhs && trans_rhs) dest_slice = alpha * mat(lhs_slice) * trans(mat(rhs_slice)) + beta * mat(dest_slice); else if (trans_lhs && !trans_rhs) dest_slice = alpha * trans(mat(lhs_slice)) * mat(rhs_slice) + beta * mat(dest_slice); else dest_slice = alpha * mat(lhs_slice) * mat(rhs_slice) + beta * mat(dest_slice); } else { if (trans_lhs && trans_rhs) dest_slice = alpha * trans(mat(lhs_slice)) * trans(mat(rhs_slice)); else if (!trans_lhs && trans_rhs) dest_slice = alpha * mat(lhs_slice) * trans(mat(rhs_slice)); else if (trans_lhs && !trans_rhs) dest_slice = alpha * trans(mat(lhs_slice)) * mat(rhs_slice); else dest_slice = alpha * mat(lhs_slice) * mat(rhs_slice); } } } } #endif } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- tensor_rand:: tensor_rand( unsigned long long seed ) #ifdef DLIB_USE_CUDA :rnd(seed){} #else {rnd.set_seed(cast_to_string(seed)); } #endif void tensor_rand:: fill_gaussian ( tensor& data, float mean, float stddev ) { DLIB_CASSERT(data.size()%2 == 0); #ifdef DLIB_USE_CUDA rnd.fill_gaussian(data, mean, stddev); #else for (auto& x : data) x = rnd.get_random_gaussian()*stddev + mean; #endif } void tensor_rand:: fill_uniform ( tensor& data ) { #ifdef DLIB_USE_CUDA rnd.fill_uniform(data); #else for (auto& x : data) x = 
rnd.get_random_float(); #endif } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void multiply ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ) { DLIB_CASSERT(dest.k() == src1.k() && src1.k() == src2.k() && dest.nr() == src1.nr() && src1.nr() == src2.nr() && dest.nc() == src1.nc() && src1.nc() == src2.nc() ); const long MD = std::max(std::max(dest.num_samples(),src1.num_samples()),src2.num_samples()); DLIB_CASSERT((dest.num_samples()==1 || dest.num_samples()==MD) && (src1.num_samples()==1 || src1.num_samples()==MD) && (src2.num_samples()==1 || src2.num_samples()==MD) ); #ifdef DLIB_USE_CUDA cuda::multiply(add_to, dest, src1, src2); #else cpu::multiply(add_to, dest, src1, src2); #endif } void scale_channels ( bool add_to, tensor& dest, const tensor& src, const tensor& scales ) { #ifdef DLIB_USE_CUDA cuda::scale_channels(add_to, dest, src, scales); #else cpu::scale_channels(add_to, dest, src, scales); #endif } void multiply_conv ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ) { #ifdef DLIB_USE_CUDA cuda::multiply_conv(add_to, dest, src1, src2); #else cpu::multiply_conv(add_to, dest, src1, src2); #endif } void multiply_zero_padded ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ) { #ifdef DLIB_USE_CUDA cuda::multiply_zero_padded(add_to, dest, src1, src2); #else cpu::multiply_zero_padded(add_to, dest, src1, src2); #endif } // ---------------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const float A, const float B ) { #ifdef DLIB_USE_CUDA cuda::affine_transform(dest,src,A,B); #else cpu::affine_transform(dest,src,A,B); #endif } void affine_transform( tensor& dest, const tensor& src, const float A ) { #ifdef DLIB_USE_CUDA cuda::affine_transform(dest,src,A); #else 
cpu::affine_transform(dest,src,A,0); #endif } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B, const float C ) { #ifdef DLIB_USE_CUDA cuda::affine_transform(dest,src1,src2,A,B,C); #else cpu::affine_transform(dest,src1,src2,A,B,C); #endif } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B ) { #ifdef DLIB_USE_CUDA cuda::affine_transform(dest,src1,src2,A,B); #else cpu::affine_transform(dest,src1,src2,A,B,0); #endif } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C, const float D ) { #ifdef DLIB_USE_CUDA cuda::affine_transform(dest,src1,src2,src3,A,B,C,D); #else cpu::affine_transform(dest,src1,src2,src3,A,B,C,D); #endif } void affine_transform_range( size_t begin, size_t end, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ) { #ifdef DLIB_USE_CUDA cuda::affine_transform_range(begin, end, dest,src1,src2,src3,A,B,C); #else cpu::affine_transform_range(begin, end, dest,src1,src2,src3,A,B,C); #endif } void affine_transform( const rectangle& rect, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, float A, float B, float C ) { #ifdef DLIB_USE_CUDA cuda::affine_transform(rect, dest,src1,src2,src3,A,B,C); #else cpu::affine_transform(rect, dest,src1,src2,src3,A,B,C); #endif } void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ) { #ifdef DLIB_USE_CUDA cuda::affine_transform_range(0,dest.size(),dest,src1,src2,src3,A,B,C); #else cpu::affine_transform_range(0,dest.size(),dest,src1,src2,src3,A,B,C); #endif } // ---------------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const tensor& A, const tensor& B ) { #ifdef 
DLIB_USE_CUDA cuda::affine_transform(dest,src,A,B); #else cpu::affine_transform(dest,src,A,B); #endif } // ---------------------------------------------------------------------------------------- void affine_transform_conv( tensor& dest, const tensor& src, const tensor& A, const tensor& B ) { #ifdef DLIB_USE_CUDA cuda::affine_transform_conv(dest,src,A,B); #else cpu::affine_transform_conv(dest,src,A,B); #endif } // ---------------------------------------------------------------------------------------- void compute_adam_update ( size_t begin, size_t end, tensor& s, tensor& m, tensor& v, const float t, const float learning_rate, const float weight_decay, const float momentum1, const float momentum2, const tensor& params, const tensor& params_grad ) { #ifdef DLIB_USE_CUDA cuda::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1, momentum2, params, params_grad); #else cpu::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1, momentum2, params, params_grad); #endif } // ---------------------------------------------------------------------------------------- void batch_normalize_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ) { #ifdef DLIB_USE_CUDA cuda::batch_normalize_inference(eps,dest,src,gamma,beta,running_means,running_variances); #else cpu::batch_normalize_inference(eps,dest,src,gamma,beta,running_means,running_variances); #endif } void batch_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& vars, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ) { #ifdef DLIB_USE_CUDA cuda::batch_normalize(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); #else 
cpu::batch_normalize(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); #endif } void batch_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ) { #ifdef DLIB_USE_CUDA cuda::batch_normalize_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); #else cpu::batch_normalize_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); #endif } // ---------------------------------------------------------------------------------------- void batch_normalize_conv_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ) { #ifdef DLIB_USE_CUDA cuda::batch_normalize_conv_inference(eps,dest,src,gamma,beta,running_means,running_variances); #else cpu::batch_normalize_conv_inference(eps,dest,src,gamma,beta,running_means,running_variances); #endif } void batch_normalize_conv ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& vars, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ) { #ifdef DLIB_USE_CUDA cuda::batch_normalize_conv(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); #else cpu::batch_normalize_conv(eps,dest,means,vars,averaging_factor,running_means,running_variances,src,gamma,beta); #endif } void batch_normalize_conv_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ) { #ifdef DLIB_USE_CUDA cuda::batch_normalize_conv_gradient(eps,gradient_input, means, invstds, src, 
gamma, src_grad, gamma_grad, beta_grad); #else cpu::batch_normalize_conv_gradient(eps,gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad); #endif } // ---------------------------------------------------------------------------------------- void layer_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& vars, const tensor& src, const tensor& gamma, const tensor& beta ) { #ifdef DLIB_USE_CUDA cuda::layer_normalize(eps, dest, means, vars, src, gamma, beta); #else cpu::layer_normalize(eps, dest, means, vars, src, gamma, beta); #endif } void layer_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad, resizable_tensor& dmeans, resizable_tensor& dvars ) { #ifdef DLIB_USE_CUDA cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars); #else cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad, dmeans, dvars); #endif } // ---------------------------------------------------------------------------------------- void rms_normalize( const double eps, resizable_tensor& dest, resizable_tensor& scale, const tensor& src, const tensor& gamma ) { #ifdef DLIB_USE_CUDA cuda::rms_normalize(eps, dest, scale, src, gamma); #else cpu::rms_normalize(eps, dest, scale, src, gamma); #endif } void rms_normalize_gradient( const tensor& gradient_input, const tensor& scale, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, resizable_tensor& dscale ) { #ifdef DLIB_USE_CUDA cuda::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale); #else cpu::rms_normalize_gradient(gradient_input, scale, src, gamma, src_grad, gamma_grad, dscale); #endif } // 
---------------------------------------------------------------------------------------- void threshold ( tensor& data, float thresh ) { #ifdef DLIB_USE_CUDA cuda::threshold(data,thresh); #else cpu::threshold(data,thresh); #endif } void dot ( const tensor& a, const tensor& b, tensor& result, size_t idx ) { #ifdef DLIB_USE_CUDA cuda::dot(a,b,result,idx); #else cpu::dot(a,b,result,idx); #endif } // ---------------------------------------------------------------------------------------- void add( float beta, tensor& dest, float alpha, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::add(beta,dest,alpha,src); #else cpu::add(beta,dest,alpha,src); #endif } // ---------------------------------------------------------------------------------------- void add ( tensor& dest, const tensor& src1, const tensor& src2 ) { #ifdef DLIB_USE_CUDA cuda::add(dest, src1, src2); #else cpu::add(dest, src1, src2); #endif } // ---------------------------------------------------------------------------------------- void assign_conv_bias_gradient ( tensor& grad, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::assign_conv_bias_gradient(grad,gradient_input); #else cpu::assign_conv_bias_gradient(grad,gradient_input); #endif } // ---------------------------------------------------------------------------------------- void assign_bias_gradient ( tensor& grad, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::assign_bias_gradient(grad,gradient_input); #else cpu::assign_bias_gradient(grad,gradient_input); #endif } // ---------------------------------------------------------------------------------------- void softmax( tensor& dest, const tensor& src, operation_mode mode ) { #ifdef DLIB_USE_CUDA cuda::softmax(dest, src, mode); #else cpu::softmax(dest, src, mode); #endif } void softmax_gradient( tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode ) { #ifdef DLIB_USE_CUDA cuda::softmax_gradient(grad, dest, gradient_input, mode); #else 
cpu::softmax_gradient(grad, dest, gradient_input, mode); #endif } // ---------------------------------------------------------------------------------------- void softmax_all ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::softmax_all(dest,src); #else cpu::softmax_all(dest,src); #endif } void softmax_all_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::softmax_all_gradient(grad, dest, gradient_input); #else cpu::softmax_all_gradient(grad, dest, gradient_input); #endif } // ---------------------------------------------------------------------------------------- void sigmoid ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::sigmoid(dest,src); #else cpu::sigmoid(dest,src); #endif } void sigmoid_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::sigmoid_gradient(grad, dest, gradient_input); #else cpu::sigmoid_gradient(grad, dest, gradient_input); #endif } // ---------------------------------------------------------------------------------------- void mish ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::mish(dest,src); #else cpu::mish(dest,src); #endif } void mish_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::mish_gradient(grad, src, gradient_input); #else cpu::mish_gradient(grad, src, gradient_input); #endif } // ---------------------------------------------------------------------------------------- void relu ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::relu(dest,src); #else cpu::relu(dest,src); #endif } void relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::relu_gradient(grad, dest, gradient_input); #else cpu::relu_gradient(grad, dest, gradient_input); #endif } // ---------------------------------------------------------------------------------------- void prelu ( tensor& 
dest, const tensor& src, const tensor& param ) { #ifdef DLIB_USE_CUDA cuda::prelu(dest, src, param); #else cpu::prelu(dest, src, param); #endif } void prelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input, const tensor& param, tensor& params_grad ) { #ifdef DLIB_USE_CUDA cuda::prelu_gradient(grad, src, gradient_input, param, params_grad); #else cpu::prelu_gradient(grad, src, gradient_input, param, params_grad); #endif } // ---------------------------------------------------------------------------------------- void leaky_relu ( tensor& dest, const tensor& src, const float alpha ) { #ifdef DLIB_USE_CUDA cuda::leaky_relu(dest, src, alpha); #else cpu::leaky_relu(dest, src, alpha); #endif } void leaky_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ) { #ifdef DLIB_USE_CUDA cuda::leaky_relu_gradient(grad, dest, gradient_input, alpha); #else cpu::leaky_relu_gradient(grad, dest, gradient_input, alpha); #endif } // ---------------------------------------------------------------------------------------- void tanh ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::tanh(dest,src); #else cpu::tanh(dest,src); #endif } void tanh_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::tanh_gradient(grad, dest, gradient_input); #else cpu::tanh_gradient(grad, dest, gradient_input); #endif } // ---------------------------------------------------------------------------------------- void clipped_relu ( tensor& dest, const tensor& src, const float ceiling ) { #ifdef DLIB_USE_CUDA cuda::clipped_relu(dest, src, ceiling); #else cpu::clipped_relu(dest, src, ceiling); #endif } void clipped_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float ceiling ) { #ifdef DLIB_USE_CUDA cuda::clipped_relu_gradient(grad, dest, gradient_input, ceiling); #else cpu::clipped_relu_gradient(grad, dest, gradient_input, ceiling); #endif 
} // ---------------------------------------------------------------------------------------- void elu ( tensor& dest, const tensor& src, const float alpha ) { #ifdef DLIB_USE_CUDA cuda::elu(dest, src, alpha); #else cpu::elu(dest, src, alpha); #endif } void elu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ) { #ifdef DLIB_USE_CUDA cuda::elu_gradient(grad, dest, gradient_input, alpha); #else cpu::elu_gradient(grad, dest, gradient_input, alpha); #endif } // ---------------------------------------------------------------------------------------- void gelu ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::gelu(dest,src); #else cpu::gelu(dest,src); #endif } void gelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::gelu_gradient(grad, src, gradient_input); #else cpu::gelu_gradient(grad, src, gradient_input); #endif } // ---------------------------------------------------------------------------------------- void smelu ( tensor& dest, const tensor& src, const float beta ) { DLIB_CASSERT(beta > 0); #ifdef DLIB_USE_CUDA cuda::smelu(dest, src, beta); #else cpu::smelu(dest, src, beta); #endif } void smelu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float beta ) { DLIB_CASSERT(beta > 0); #ifdef DLIB_USE_CUDA cuda::smelu_gradient(grad, dest, gradient_input, beta); #else cpu::smelu_gradient(grad, dest, gradient_input, beta); #endif } // ---------------------------------------------------------------------------------------- void silu ( tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::silu(dest,src); #else cpu::silu(dest,src); #endif } void silu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::silu_gradient(grad, src, gradient_input); #else cpu::silu_gradient(grad, src, gradient_input); #endif } // 
---------------------------------------------------------------------------------------- void resize_bilinear ( tensor& dest, long dest_row_stride, long dest_channel_stride, const tensor& src, long src_row_stride, long src_channel_stride ) { #ifdef DLIB_USE_CUDA cuda::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride); #else cpu::resize_bilinear(dest,dest_row_stride,dest_channel_stride, src,src_row_stride,src_channel_stride); #endif } void resize_bilinear_gradient ( tensor& grad, long grad_row_stride, long grad_channel_stride, const tensor& gradient_input, long gradient_input_row_stride, long gradient_input_channel_stride ) { #ifdef DLIB_USE_CUDA cuda::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride); #else cpu::resize_bilinear_gradient(grad,grad_row_stride,grad_channel_stride, gradient_input,gradient_input_row_stride,gradient_input_channel_stride); #endif } // ------------------------------------------------------------------------------------ void reorg ( bool add_to, tensor& dest, const int row_stride, const int col_stride, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::reorg(add_to, dest, row_stride, col_stride, src); #else cpu::reorg(add_to, dest, row_stride, col_stride, src); #endif } void reorg_gradient ( bool add_to, tensor& grad, const int row_stride, const int col_stride, const tensor& gradient_input ) { #ifdef DLIB_USE_CUDA cuda::reorg_gradient(add_to, grad, row_stride, col_stride, gradient_input); #else cpu::reorg_gradient(add_to, grad, row_stride, col_stride, gradient_input); #endif } // ------------------------------------------------------------------------------------ void copy_tensor( bool add_to, tensor& dest, size_t dest_k_offset, const tensor& src, size_t src_k_offset, size_t count_k ) { #ifdef DLIB_USE_CUDA cuda::copy_tensor(add_to, dest, dest_k_offset, src, src_k_offset, count_k); #else 
cpu::copy_tensor(add_to, dest, dest_k_offset, src, src_k_offset, count_k); #endif } // ---------------------------------------------------------------------------------------- void copy_tensor( bool add_to, tensor& dest, size_t dk, size_t dnr, size_t dnc, const tensor& src, size_t sk, size_t snr, size_t snc, size_t k, size_t nr, size_t nc ) { #ifdef DLIB_USE_CUDA cuda::copy_tensor(add_to, dest, dk, dnr, dnc , src, sk, snr, snc, k, nr, nc); #else cpu::copy_tensor(add_to, dest, dk, dnr, dnc, src, sk, snr, snc, k, nr, nc); #endif } // ---------------------------------------------------------------------------------------- void inv:: operator() ( const tensor& m, resizable_tensor& out ) { #ifdef DLIB_USE_CUDA finv(m,out); #else out = dlib::inv(mat(m)); #endif } // ---------------------------------------------------------------------------------------- void transpose( bool add_to, tensor& dest, const tensor& src ) { #ifdef DLIB_USE_CUDA cuda::transpose(add_to, dest, src); #else cpu::transpose(add_to, dest, src); #endif } // ---------------------------------------------------------------------------------------- void embeddings( resizable_tensor& dest, const tensor& src, const tensor& embs ) { #ifdef DLIB_USE_CUDA cuda::embeddings(dest, src, embs); #else cpu::embeddings(dest, src, embs); #endif } void embeddings_gradient( const tensor& prev, const tensor& gradient_input, tensor& grads, const tensor& freqs, float learning_rate, bool scale ) { #ifdef DLIB_USE_CUDA cuda::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale); #else cpu::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale); #endif } // ---------------------------------------------------------------------------------------- void compute_act_halt_probabilities( resizable_tensor& halt_probs, resizable_tensor& logits, const tensor& input_data, const tensor& halt_params, long batch_size, long seq_len, long feature_dim ) { #ifdef DLIB_USE_CUDA 
cuda::compute_act_halt_probabilities(halt_probs, logits, input_data, halt_params, batch_size, seq_len, feature_dim); #else cpu::compute_act_halt_probabilities(halt_probs, logits, input_data, halt_params, batch_size, seq_len, feature_dim); #endif } void update_act_state( resizable_tensor& output, const tensor& input_data, const tensor& halt_probs, resizable_tensor& cumulative_halting, resizable_tensor& remainders, resizable_tensor& n_steps, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels, float halt_threshold, long current_step ) { #ifdef DLIB_USE_CUDA cuda::update_act_state(output, input_data, halt_probs, cumulative_halting, remainders, n_steps, effective_weights, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step); #else cpu::update_act_state(output, input_data, halt_probs, cumulative_halting, remainders, n_steps, effective_weights, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step); #endif } void finalize_act_output( resizable_tensor& output, const tensor& input_data, const tensor& remainders, resizable_tensor& effective_weights, long batch_size, long seq_len, long d_model, long num_channels ) { #ifdef DLIB_USE_CUDA cuda::finalize_act_output(output, input_data, remainders, effective_weights, batch_size, seq_len, d_model, num_channels); #else cpu::finalize_act_output(output, input_data, remainders, effective_weights, batch_size, seq_len, d_model, num_channels); #endif } void apply_act_depth_scaling( tensor& gradients, const tensor& n_steps, long batch_size, long seq_len, long d_model, long num_channels, float max_steps, float scale_factor ) { #ifdef DLIB_USE_CUDA cuda::apply_act_depth_scaling(gradients, n_steps, batch_size, seq_len, d_model, num_channels, max_steps, scale_factor); #else cpu::apply_act_depth_scaling(gradients, n_steps, batch_size, seq_len, d_model, num_channels, max_steps, scale_factor); #endif } // 
---------------------------------------------------------------------------------------- }} #endif // DLIB_TeNSOR_TOOLS_CPP_ ================================================ FILE: dlib/cuda/tensor_tools.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_TeNSOR_TOOLS_H_ #define DLIB_TeNSOR_TOOLS_H_ #include "tensor.h" #include "cudnn_dlibapi.h" #include "cublas_dlibapi.h" #include "cusolver_dlibapi.h" #include "curand_dlibapi.h" #include "cpu_dlib.h" #include "cuda_dlib.h" #include "../rand.h" #include #include "../geometry/rectangle.h" #include "../test_for_odr_violations.h" namespace dlib { bool dnn_prefer_fastest_algorithms(); void set_dnn_prefer_fastest_algorithms(); void set_dnn_prefer_smallest_algorithms(); } namespace dlib { namespace tt { // ---------------------------------------------------------------------------------------- void inverse_norms ( resizable_tensor& invnorms, const tensor& data, const double eps ); /*! ensures - #invnorms == reciprocal(sqrt(sum_cols(squared(mat(data))) + eps)) !*/ void dot_prods ( resizable_tensor& out, const tensor& lhs, const tensor& rhs ); /*! requires - have_same_dimensions(lhs,rhs) == true ensures - #out.num_samples() == lhs.num_samples() - #out.k() == #out.nr() == #out.nc() == 1 - #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); !*/ void dot_prods ( bool add_to, tensor& out, const tensor& lhs, const tensor& rhs ); /*! requires - have_same_dimensions(lhs,rhs) == true - out.size() == lhs.num_samples() - out.k() == out.nr() == out.nc() == 1 ensures - if (add_to) then - #out == mat(out) + sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); - else - #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs))); !*/ void scale_columns ( tensor& out, const tensor& m, const tensor& v ); /*! 
requires - have_same_dimensions(out,m) == true - is_vector(v) == true - v.size() == mat(m).nc() ensures - performs: out = scale_columns(mat(m),mat(v)); !*/ void scale_rows ( tensor& out, const tensor& m, const tensor& v ); /*! requires - have_same_dimensions(out,m) == true - is_vector(v) == true - v.size() == m.num_samples() ensures - performs: out = scale_rows(mat(m),mat(v)); !*/ void scale_rows2 ( float beta, tensor& out, const tensor& m1, const tensor& m2, const tensor& v1, const tensor& v2 ); /*! requires - have_same_dimensions(out,m1) == true - have_same_dimensions(out,m2) == true - have_same_dimensions(v1,v2) == true - is_vector(v1) == true - v1.size() == m1.num_samples() ensures - performs: out = beta*out + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)); !*/ // ---------------------------------------------------------------------------------------- void exp ( tensor& dest, const tensor& src ); /*! requires - dest.size() == src.size() ensures - performs: dest = exp(mat(src)) !*/ // ---------------------------------------------------------------------------------------- void log ( tensor& dest, const tensor& src ); /*! requires - dest.size() == src.size() ensures - performs: dest = log(mat(src)) !*/ // ---------------------------------------------------------------------------------------- void log10 ( tensor& dest, const tensor& src ); /*! requires - dest.size() == src.size() ensures - performs: dest = log10(mat(src)) !*/ // ---------------------------------------------------------------------------------------- void gemm ( float beta, tensor& dest, float alpha, const tensor& lhs, bool trans_lhs, const tensor& rhs, bool trans_rhs, operation_mode mode = operation_mode::CHANNEL_WISE ); /*! requires - dest does not alias the memory of lhs or rhs - The dimensions of lhs and rhs must be compatible for matrix multiplication. The specific requirements depend on the mode: For CHANNEL_WISE mode (default): - Let L == trans_lhs ? 
trans(mat(lhs)) : mat(lhs) - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs) - Let D == mat(dest) - D.nr() == L.nr() && D.nc() == R.nc() (i.e. dest must be preallocated and have the correct output dimensions) - L.nc() == R.nr() For PLANE_WISE mode: - lhs.num_samples() == rhs.num_samples() && lhs.k() == rhs.k() - If !trans_lhs && !trans_rhs: lhs.nc() == rhs.nr() dest.nr() == lhs.nr() && dest.nc() == rhs.nc() - If trans_lhs && !trans_rhs: lhs.nr() == rhs.nr() dest.nr() == lhs.nc() && dest.nc() == rhs.nc() - If !trans_lhs && trans_rhs: lhs.nc() == rhs.nc() dest.nr() == lhs.nr() && dest.nc() == rhs.nr() - If trans_lhs && trans_rhs: lhs.nr() == rhs.nc() dest.nr() == lhs.nc() && dest.nc() == rhs.nr() ensures - Performs matrix multiplication based on the specified mode: For CHANNEL_WISE mode: - performs: dest = alpha*L*R + beta*mat(dest) where L, R, and D are as defined above. For PLANE_WISE mode: - Performs matrix multiplication for each corresponding 2D plane (nr x nc) in lhs and rhs across all samples and channels. - The operation is equivalent to performing the following for each sample and channel: dest[s][k] = alpha * (lhs[s][k] * rhs[s][k]) + beta * dest[s][k] where [s][k] represents the 2D plane for sample s and channel k. Note that the PLANE_WISE mode is particularly useful for operations like attention mechanisms in neural networks, where you want to perform matrix multiplications on 2D planes of 4D tensors while preserving the sample and channel dimensions. !*/ // ---------------------------------------------------------------------------------------- class inv { /*! WHAT THIS OBJECT REPRESENTS This is a functor for doing matrix inversion on the GPU. The only reason it's an object is to avoid the reallocation of some GPU memory blocks if you want to do a bunch of matrix inversions in a row. !*/ public: void operator() ( const tensor& m, resizable_tensor& out ); /*! requires - m.size() == m.num_samples()*m.num_samples() (i.e. 
mat(m) must be a square matrix) ensures - out == inv(mat(m)); !*/ private: #ifdef DLIB_USE_CUDA cuda::inv finv; #endif }; // ---------------------------------------------------------------------------------------- class tensor_rand { /*! WHAT THIS OBJECT REPRESENTS This is a tool for filling a tensor with random numbers. Note that the sequence of random numbers output by this object is different when dlib is compiled with DLIB_USE_CUDA. So you should not write code that depends on any specific sequence of numbers coming out of a tensor_rand. !*/ public: // not copyable tensor_rand(const tensor_rand&) = delete; tensor_rand& operator=(const tensor_rand&) = delete; tensor_rand() : tensor_rand(0) {} tensor_rand(unsigned long long seed); void fill_gaussian ( tensor& data, float mean = 0, float stddev = 1 ); /*! requires - data.size()%2 == 0 ensures - Fills data with random numbers drawn from a Gaussian distribution with the given mean and standard deviation. !*/ void fill_uniform ( tensor& data ); /*! ensures - Fills data with uniform random numbers in the range (0.0, 1.0]. !*/ #ifdef DLIB_USE_CUDA cuda::curand_generator rnd; #else dlib::rand rnd; #endif }; // ---------------------------------------------------------------------------------------- void multiply ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ); /*! requires - dest.k() == src1.k() == src2.k() - dest.nr() == src1.nr() == src2.nr() - dest.nc() == src1.nc() == src2.nc() - dest.num_samples(), src1.num_samples(), and src2.num_samples() must each either be 1 or whichever ones aren't equal to 1 must have the same values. ensures - let MD = max(dest.num_samples(), src1.num_samples(), src2.num_samples) - This function pointwise multiplies src1 with src2 and stores the result into #dest. However, how the multiplication happens depends on the dimensions of the tensors. 
First, when src1 and src2 are multiplied together, if either has a num_samples() dimension that is != MD, then it is first replicated to produce a tensor with num_samples()==MD dimensions and then they are pointwise multiplied together. Second, if dest.num_samples()==1, then after the pointwise multiplication of src1 with src2, the result has its samples summed to produce an output tensor with num_samples()==1 which is then assigned to #dest. - if (add_to) then - Instead of assigning the result to dest, this function adds the result to dest. !*/ void scale_channels ( bool add_to, tensor& dest, const tensor& src, const tensor& scales ); /*! requires - have_same_dimensions(dest, src) == true - scales.num_samples() == src.num_samples() - scales.k() == src.k() - scales.nr() == 1 - scales.nc() == 1 ensures - Scales each channel of src by the corresponding value in scales. To be precise, we will have: - #dest(n,k,r,c) == src(n,k,r,c)*scales(n,k,1,1) - if (add_to) then - Instead of assigning the result to dest, this function adds the result to dest. !*/ void multiply_conv ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ); /*! requires - if (have_same_dimensions(dest, src1) == true) then - src2.num_samples() == 1 - src2.nr() == 1 - src2.nc() == 1 - src2.k() == src1.k() - else - have_same_dimensions(src1, src2) == true - dest.num_samples() == 1 - dest.nr() == 1 - dest.nc() == 1 - dest.k() == src1.k() ensures - Performs #dest == src1*src2 In particular, if the elements of dest, src1, and src2 were indexed by (n,k,r,c) then we would have: - if (have_same_dimensions(dest,src1)) then #dest(n,k,r,c) == src1(n,k,r,c)*src2(k) - else #dest(k) == sum over {n,r,c} of src1(n,k,r,c)*src2(n,k,r,c) - if (add_to) then - Instead of assigning the result to dest, this function adds the result to dest. !*/ void multiply_zero_padded ( bool add_to, tensor& dest, const tensor& src1, const tensor& src2 ); /*!
ensures - if (add_to) then - performs: dest += src1 * src2 - else - performs: dest = src1 * src2 - In either case, the multiplication happens pointwise according to 4D tensor arithmetic. If the dimensions don't match then missing elements are presumed to be equal to 0. !*/ // ---------------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const float A, const float B ); /*! requires - dest.size()==src.size() ensures - #dest == A*src + B !*/ void affine_transform( tensor& dest, const tensor& src, const float A ); /*! requires - dest.size()==src.size() ensures - #dest == A*src !*/ void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B, const float C ); /*! requires - dest.size()==src1.size() - dest.size()==src2.size() ensures - #dest == A*src1 + B*src2 + C !*/ void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const float A, const float B ); /*! requires - dest.size()==src1.size() - dest.size()==src2.size() ensures - #dest == A*src1 + B*src2 !*/ void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C, const float D ); /*! requires - dest.size()==src1.size() - dest.size()==src2.size() - dest.size()==src3.size() ensures - #dest == A*src1 + B*src2 + C*src3 + D !*/ void affine_transform( tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ); /*! requires - dest.size()==src1.size() - dest.size()==src2.size() - dest.size()==src3.size() ensures - #dest == A*src1 + B*src2 + C*src3 !*/ void affine_transform_range( size_t begin, size_t end, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, const float A, const float B, const float C ); /*! 
requires - dest.size()==src1.size() - dest.size()==src2.size() - dest.size()==src3.size() - begin <= end <= dest.size() ensures - This function operates much like affine_transform(dest,src1,src2,src3,A,B,C,0), except that it runs over only the half open range [begin,end) rather than processing the entire tensor. Specifically, it does this: - for i in the range [begin, end): - #dest.host()[i] == A*src1.host()[i] + B*src2.host()[i] + C*src3.host()[i] !*/ void affine_transform( const rectangle& rect, tensor& dest, const tensor& src1, const tensor& src2, const tensor& src3, float A, float B, float C ); /*! requires - dest.size()==src1.size() - dest.size()==src2.size() - dest.size()==src3.size() - dest.num_samples()==src1.num_samples() - dest.num_samples()==src2.num_samples() - dest.num_samples()==src3.num_samples() - get_rect(mat(dest)).contains(rect) == true (i.e. rect must be entirely contained within dest) ensures - This function operates much like affine_transform(dest,src1,src2,src3,A,B,C,0), except that it runs over only the sub-rectangle indicated by rect. In particular, this function is equivalent to: set_subm(dest,rect) = A*subm(mat(src1),rect) + B*subm(mat(src2),rect) + C*subm(mat(src3),rect) !*/ // ---------------------------------------------------------------------------------------- void affine_transform( tensor& dest, const tensor& src, const tensor& A, const tensor& B ); /*! 
requires - have_same_dimensions(dest,src) == true - if (A.num_samples() == 1) then - B.num_samples() == 1 - else - A.num_samples() == src.num_samples() - B.num_samples() == src.num_samples() - A.nr() == B.nr() == src.nr() - A.nc() == B.nc() == src.nc() - A.k() == B.k() == src.k() ensures - if (A.num_samples() == 1) then - #dest == A*src + B (done for each sample in src) - else - for all valid i: - #dest.host()[i] == A.host()[i]*src.host()[i] + B.host()[i] !*/ // ---------------------------------------------------------------------------------------- void affine_transform_conv( tensor& dest, const tensor& src, const tensor& A, const tensor& B ); /*! requires - have_same_dimensions(dest,src) == true - have_same_dimensions(A, B) == true - A.num_samples() == 1 - A.nr() == 1 - A.nc() == 1 - A.k() == src.k() ensures - Performs #dest == A*src + B In particular, if the elements of dest and src were indexed by (n,k,r,c) then we would have: #dest(n,k,r,c) == A(k)*src(n,k,r,c) + B(k). !*/ // ---------------------------------------------------------------------------------------- void compute_adam_update ( size_t begin, size_t end, tensor& s, tensor& m, tensor& v, const float t, const float learning_rate, const float weight_decay, const float momentum1, const float momentum2, const tensor& params, const tensor& params_grad ); /*! requires - s.size() == m.size() == v.size() == params.size() == params_grad.size() - t > 0 - learning_rate > 0 - weight_decay >= 0 - 0 <= momentum1 < 1 - 0 <= momentum2 < 1 - begin <= end <= params.size() ensures - This function implements the ADAM parameter update method described in the paper: Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic optimization." International Conference on Learning Representations. 2015. Specifically, it implements the method shown as Algorithm 1. - #s is the update vector that should be added to the parameters.
- The function only operates in the half open range [begin,end) of the memory blocks of each tensor. E.g. to make this function run on the entire tensor set begin to 0 and end to params.size(). !*/ // ---------------------------------------------------------------------------------------- void batch_normalize_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ); /*! requires - eps > 0 - gamma.num_samples() == 1 - gamma.nr() == src.nr() - gamma.nc() == src.nc() - gamma.k() == src.k() - have_same_dimensions(gamma, beta) - have_same_dimensions(gamma, running_means) - have_same_dimensions(gamma, running_variances) ensures - Linearly transforms src as a call to batch_normalize() would if src had means and variances as given by running_means and running_variances. That is, this function performs: dest = gamma*(src-running_means)/sqrt(running_variances+eps) + beta Note that it does it in a pointwise fashion over the samples in src. !*/ void batch_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ); /*! 
requires - eps > 0 - src.num_samples() > 1 - gamma.num_samples() == 1 - beta.num_samples() == 1 - gamma.nr() == beta.nr() == src.nr() - gamma.nc() == beta.nc() == src.nc() - gamma.k() == beta.k() == src.k() - 0 <= averaging_factor <= 1 - if (averaging_factor != 1) - have_same_dimensions(running_means, means) == true - have_same_dimensions(running_variances, invstds) == true ensures - have_same_dimensions(#dest, src) == true - #means.num_samples() == 1 - #invstds.num_samples() == 1 - means.nr() == invstds.nr() == src.nr() - means.nc() == invstds.nc() == src.nc() - means.k() == invstds.k() == src.k() - #dest == the batch normalized version of src. - #means == the mean values of the contents of src. - #invstds == 1/(the standard deviation values of the contents of src). - #running_means = (1-averaging_factor)*mat(#running_means) + averaging_factor*mat(#means); - #running_variances = (1-averaging_factor)*mat(#running_variances) + averaging_factor*(variance of contents of src); !*/ void batch_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ); /*! requires - eps > 0 - invstds and means should be the output of a call to batch_normalize(eps,dest,means,invstds,src,gamma,beta) - have_same_dimensions(gradient_input, src) == true - have_same_dimensions(src, src_grad) == true - src.num_samples() > 1 - gamma.num_samples() == 1 - have_same_dimensions(gamma, gamma_grad) == true - have_same_dimensions(gamma, beta_grad) == true - gamma.nr() == src.nr() - gamma.nc() == src.nc() - gamma.k() == src.k() - have_same_dimensions(means, gamma) == true - have_same_dimensions(invstds, gamma) == true ensures - Let f(src,gamma,beta) == dot(gradient_input, dest output of batch_normalize(eps,dest,means,invstds,src,gamma,beta)) - Adds the gradient of f() with respect to src to #src_grad. 
- Assigns the gradient of f() with respect to gamma to #gamma_grad. - Assigns the gradient of f() with respect to beta to #beta_grad. !*/ // ---------------------------------------------------------------------------------------- void batch_normalize_conv_inference ( const double eps, resizable_tensor& dest, const tensor& src, const tensor& gamma, const tensor& beta, const tensor& running_means, const tensor& running_variances ); /*! requires - eps > 0 - gamma.num_samples() == 1 - gamma.nr() == 1 - gamma.nc() == 1 - gamma.k() == src.k() - have_same_dimensions(gamma, beta) - have_same_dimensions(gamma, running_means) - have_same_dimensions(gamma, running_variances) ensures - Linearly transforms src as a call to batch_normalize_conv() would if src had means and variances as given by running_means and running_variances. That is, this function performs: dest = gamma*(src-running_means)/sqrt(running_variances+eps) + beta Note that it does this in a pointwise fashion over the samples, rows, and columns in src. !*/ void batch_normalize_conv ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const double averaging_factor, resizable_tensor& running_means, resizable_tensor& running_variances, const tensor& src, const tensor& gamma, const tensor& beta ); /*! requires - eps > 0 - src.num_samples() > 1 - gamma.num_samples()==gamma.nr()==gamma.nc() == 1 - beta.num_samples() ==beta.nr() ==gamma.nc() == 1 - gamma.k() == beta.k() == src.k() - 0 <= averaging_factor <= 1 - if (averaging_factor != 1) - have_same_dimensions(running_means, means) == true - have_same_dimensions(running_variances, invstds) == true ensures - have_same_dimensions(#dest, src) == true - #means.num_samples()==means.nr()==means.nc() == 1 - #invstds.num_samples() ==invstds.nr() ==invstds.nc() == 1 - means.k() == invstds.k() == src.k() - #dest == the batch normalized version of src. - #means == the mean values of the contents of src. 
- #invstds == 1/(the standard deviation values of the contents of src). - #running_means = (1-averaging_factor)*mat(#running_means) + averaging_factor*mat(#means); - #running_variances = (1-averaging_factor)*mat(#running_variances) + averaging_factor*(variance of contents of src); !*/ void batch_normalize_conv_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad ); /*! requires - eps > 0 - invstds and means should be the output of a call to batch_normalize_conv(eps,dest,means,invstds,src,gamma,beta) - have_same_dimensions(gradient_input, src) == true - have_same_dimensions(src, src_grad) == true - src.num_samples() > 1 - gamma.num_samples()==gamma.nr()==gamma.nc() == 1 - have_same_dimensions(gamma, gamma_grad) == true - have_same_dimensions(gamma, beta_grad) == true - gamma.k() == src.k() - have_same_dimensions(means, gamma) == true - have_same_dimensions(invstds, gamma) == true ensures - Let f(src,gamma,beta) == dot(gradient_input, dest output of batch_normalize_conv(eps,dest,means,invstds,src,gamma,beta)) - Adds the gradient of f() with respect to src to #src_grad. - Assigns the gradient of f() with respect to gamma to #gamma_grad. - Assigns the gradient of f() with respect to beta to #beta_grad. !*/ // ----------------------------------------------------------------------------------- void layer_normalize ( const double eps, resizable_tensor& dest, resizable_tensor& means, resizable_tensor& invstds, const tensor& src, const tensor& gamma, const tensor& beta ); /*! requires - eps > 0 - src.k() == gamma.size() == beta.size() - gamma.num_samples() == gamma.nr() == gamma.nc() == 1 - have_same_dimensions(gamma, beta) == true ensures - have_same_dimensions(#dest, src) == true - #means.size() == invstds.size() == src.num_samples() - #dest == the normalized version of src, sample-wise. 
- #means == the mean values of the contents of src. - #invstds == 1/(the standard deviation values of the contents of src). !*/ void layer_normalize_gradient ( const double eps, const tensor& gradient_input, const tensor& means, const tensor& invstds, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, tensor& beta_grad, resizable_tensor& dmeans, resizable_tensor& dvars ); /*! requires - eps > 0 - invstds and means should be the output of a call to layer_normalize(eps,dest,means,invstds,src,gamma,beta) - have_same_dimensions(gradient_input, src) == true - have_same_dimensions(src, src_grad) == true - have_same_dimensions(gamma, gamma_grad) == true - have_same_dimensions(gamma, beta_grad) == true - means.size() == src.num_samples() - invstds.size() == src.num_samples() ensures - Let f(src,gamma,beta) == dot(gradient_input, dest output of layer_normalize(eps,dest,means,invstds,src,gamma,beta)) - Adds the gradient of f() with respect to src to #src_grad. - Assigns the gradient of f() with respect to gamma to #gamma_grad. - Assigns the gradient of f() with respect to beta to #beta_grad. !*/ // ----------------------------------------------------------------------------------- void rms_normalize( const double eps, resizable_tensor& dest, resizable_tensor& scale, const tensor& src, const tensor& gamma ); /*! requires - eps > 0 - gamma.k() == src.k() - gamma.nr() == 1 - gamma.nc() == 1 ensures - have_same_dimensions(#dest, src) == true - #scale.size() == src.num_samples() - #dest == the RMS normalized version of src - #scale contains the RMS (Root Mean Square) values used to normalize each sample of src. - Each element of #dest is computed as: - #dest[n, k, i, j] == src[n, k, i, j] * gamma[k] / scale[n] where n is the sample index, k is the channel index, and i, j are the spatial indices. 
!*/ void rms_normalize_gradient( const tensor& gradient_input, const tensor& scale, const tensor& src, const tensor& gamma, tensor& src_grad, tensor& gamma_grad, resizable_tensor& dscale ); /*! requires - scale.size() == src.num_samples() - have_same_dimensions(gamma, gamma_grad) - gamma.k() == src.k() - gamma.nr() == 1 - gamma.nc() == 1 - have_same_dimensions(gradient_input, src) - have_same_dimensions(gradient_input, src_grad) ensures - Let f(src, gamma) == dot(gradient_input, dest output of rms_normalize(eps, dest, scale, src, gamma)) - Adds the gradient of f() with respect to src to #src_grad - Assigns the gradient of f() with respect to gamma to #gamma_grad - #dscale contains the gradients of f() with respect to the RMS values. !*/ // ----------------------------------------------------------------------------------- void threshold ( tensor& data, float thresh ); /*! ensures - Sets all elements of data to 1 or 0 depending on if they are above or below the given threshold. Specifically, for all valid i: - #data.host()[i] == data.host()[i]>thresh ? 1 : 0 !*/ void dot ( const tensor& a, const tensor& b, tensor& result, size_t idx ); /*! requires - a.size() == b.size() - idx < result.size() ensures - #result.host()[idx] == result.host()[idx] + dot(a,b); I.e. Adds the dot product between a and b into the idx-th element of result. The reason you might want to use this more complex version of dot() is because, when using CUDA, it runs by generating asynchronous kernel launches whereas the version of dot() that returns the result immediately as a scalar must block the host while we wait for the result to be computed and then transferred from the GPU do the host for return by dot(). So this version of dot() might be much faster in some cases. !*/ // ---------------------------------------------------------------------------------------- void add( float beta, tensor& dest, float alpha, const tensor& src ); /*! 
requires - One of the following is true: - have_same_dimensions(src, dest) - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1 - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc() - src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc() - src.num_samples()==dest.num_samples() && src.k()==1 && src.nr()==1 && src.nc()==1 - is_same_object(src,dest) == false ensures - performs: dest = beta*dest + alpha*src However, how the addition happens depends on the dimensions of src. In particular, this function adds the scaled values of one src tensor to dest. Each dimension of the src tensor must match the corresponding dimension of the dest tensor or must be equal to 1. In the latter case, the same value from the src tensor, for those dimensions, will be used to add into the dest tensor. !*/ // ---------------------------------------------------------------------------------------- void add ( tensor& dest, const tensor& src1, const tensor& src2 ); /*! ensures - performs: dest = src1 + src2 The addition happens pointwise according to 4D tensor arithmetic. If the dimensions don't match then missing elements are presumed to be equal to 0. !*/ // ---------------------------------------------------------------------------------------- void assign_conv_bias_gradient ( tensor& grad, const tensor& gradient_input ); /*! requires - grad.num_samples() == 1 - grad.k() >= 1 - grad.nr() == 1 - grad.nc() == 1 - gradient_input.k() == grad.k() - gradient_input.size() > 0 - is_same_object(grad,gradient_input) == false ensures - let BIAS be a tensor with the same dimensions as grad. - let OUT be the output of add(1,OUT,1,BIAS) - let f(gradient_input,BIAS) == dot(gradient_input,OUT) - Then this function computes the gradient of f() with respect to BIAS and assigns it to grad. 
!*/ // ---------------------------------------------------------------------------------------- void assign_bias_gradient ( tensor& grad, const tensor& gradient_input ); /*! requires - grad.num_samples() == 1 - gradient_input.k() == grad.k() - gradient_input.nr() == grad.nr() - gradient_input.nc() == grad.nc() - gradient_input.size() > 0 - is_same_object(grad,gradient_input) == false ensures - let BIAS be a tensor with the same dimensions as grad. - let OUT be the output of add(1,OUT,1,BIAS) - let f(gradient_input,BIAS) == dot(gradient_input,OUT) - Then this function computes the gradient of f() with respect to BIAS and assigns it to grad. !*/ // ---------------------------------------------------------------------------------------- class tensor_conv { public: tensor_conv(const tensor_conv&) = delete; tensor_conv& operator=(const tensor_conv&) = delete; tensor_conv() {} void clear( ) { impl.clear(); } void operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters ) { impl(add_to_output,output,data,filters); } /*! requires - setup() has been called. Specifically, setup() has been called like this: this->setup(data, filters, stride_y, stride_x, padding_y, padding_x); - is_same_object(output,data) == false - is_same_object(output,filters) == false - filters.k() == data.k() - filters.nr() <= src.nr() + 2*padding_y - filters.nc() <= src.nc() + 2*padding_x - #output.num_samples() == data.num_samples() - #output.k() == filters.num_samples() - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x ensures - Convolves filters over data. If add_to_output==true then we add the results to output, otherwise we assign to output, overwriting the previous values in output. - filters contains filters.num_samples() filters. 
!*/ void operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters ) { impl(add_to_output,output,data,filters); } /*! requires - setup() has been called. Specifically, setup() has been called like this: this->setup(data, filters, stride_y, stride_x, padding_y, padding_x); - is_same_object(output,data) == false - is_same_object(output,filters) == false - filters.k() == data.k() - filters.nr() <= src.nr() + 2*padding_y - filters.nc() <= src.nc() + 2*padding_x ensures - Convolves filters over data. If add_to_output==true then we add the results to output, otherwise we assign to output, overwriting the previous values in output. - filters contains filters.num_samples() filters. - #output.num_samples() == data.num_samples() - #output.k() == filters.num_samples() - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x !*/ void operator() ( const bool add_to_output, tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ) { impl(add_to_output,output,data,filters,biases,use_relu); } /*! requires - setup() has been called. Specifically, setup() has been called like this: this->setup(data, filters, stride_y, stride_x, padding_y, padding_x); - is_same_object(output,data) == false - is_same_object(output,filters) == false - filters.k() == data.k() - filters.nr() <= src.nr() + 2*padding_y - filters.nc() <= src.nc() + 2*padding_x - filters.num_samples() == biases.k() - #output.num_samples() == data.num_samples() - #output.k() == filters.num_samples() - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x ensures - Convolves filters over data. If add_to_output==true then we add the results to output, otherwise we assign to output, overwriting the previous values in output. 
- Adds biases to the result of the convolved data - filters contains filters.num_samples() filters. - If use_relu==true, then a relu activation will be applied to the result of convolution+bias. !*/ void operator() ( const bool add_to_output, resizable_tensor& output, const tensor& data, const tensor& filters, const tensor& biases, bool use_relu ) { impl(add_to_output,output,data,filters,biases,use_relu); } /*! requires - setup() has been called. Specifically, setup() has been called like this: this->setup(data, filters, stride_y, stride_x, padding_y, padding_x); - is_same_object(output,data) == false - is_same_object(output,filters) == false - filters.k() == data.k() - filters.nr() <= src.nr() + 2*padding_y - filters.nc() <= src.nc() + 2*padding_x - filters.num_samples() == biases.k() ensures - Convolves filters over data. If add_to_output==true then we add the results to output, otherwise we assign to output, overwriting the previous values in output. - Adds biases to the result of the convolved data - filters contains filters.num_samples() filters. - #output.num_samples() == data.num_samples() - #output.k() == filters.num_samples() - #output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y - #output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x !*/ void get_gradient_for_data ( const bool add_to_output, const tensor& gradient_input, const tensor& filters, tensor& data_gradient ) { impl.get_gradient_for_data(add_to_output,gradient_input,filters,data_gradient); } /*! requires - One of the following must be true: - filters has the same dimensions as the filters object given to the last call to operator(). Also, data_gradient has the same dimensions as the data object given to the last call to operator(). - setup() has been called. 
Specifically, setup() has been called like this: this->setup(data_gradient, filters, stride_y, stride_x, padding_y, padding_x); - gradient_input has the following dimensions: - gradient_input.num_samples() == data_gradient.num_samples() - gradient_input.k() == filters.num_samples() - gradient_input.nr() == 1+(data_gradient.nr() + 2*padding_y - filters.nr())/stride_y - gradient_input.nc() == 1+(data_gradient.nc() + 2*padding_x - filters.nc())/stride_x - NOTE, these dimensions are what you would obtain if gradient_input has the same dimensions as the last output of operator(). - is_same_object(data_gradient,filters) == false - is_same_object(data_gradient,gradient_input) == false ensures - let OUT be the output of (*this)(OUT,data,filters,sx,sy). - let f(data,filters) == dot(OUT, gradient_input) - if (add_to_output) then - This function finds the gradient of f() with respect to data and adds this gradient to data_gradient. - else - This function finds the gradient of f() with respect to data and assigns this gradient to data_gradient, overwriting the previous values in data_gradient. !*/ void get_gradient_for_filters ( const bool add_to_output, const tensor& gradient_input, const tensor& data, tensor& filters_gradient ) { impl.get_gradient_for_filters(add_to_output,gradient_input,data,filters_gradient); } /*! requires - One of the following must be true: - filters_gradient has the same dimensions as the filters object given to the last call to operator(). Also, data has the same dimensions as the data object given to the last call to operator(). - setup() has been called. 
Specifically, setup() has been called like this: this->setup(data, filters_gradient, stride_y, stride_x, padding_y, padding_x); - gradient_input has the following dimensions: - gradient_input.num_samples() == data.num_samples() - gradient_input.k() == filters.num_samples() - gradient_input.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y - gradient_input.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x - NOTE, these dimensions are what you would obtain if gradient_input has the same dimensions as the last output of operator(). - is_same_object(filters_gradient,data) == false - is_same_object(filters_gradient,gradient_input) == false ensures - let OUT be the output of (*this)(OUT,data,filters,sx,sy). - let f(data,filters) == dot(OUT, gradient_input) - if (add_to_output) then - This function finds the gradient of f() with respect to filters and adds this gradient to filters_gradient. - else - This function finds the gradient of f() with respect to filters and assigns this gradient to filters_gradient, overwriting the previous values in filters_gradient. !*/ void setup( const tensor& data, const tensor& filters, int stride_y, int stride_x, int padding_y, int padding_x ) {impl.setup(data,filters,stride_y,stride_x,padding_y,padding_x); } /*! requires - filters.k() == data.k() - stride_y > 0 - stride_x > 0 - 0 <= padding_y < filters.nr() - 0 <= padding_x < filters.nc() ensures - When operator() is called, the output tensor will have these dimensions: - output.nr() == 1+(data.nr() + 2*padding_y - filters.nr())/stride_y - output.nc() == 1+(data.nc() + 2*padding_x - filters.nc())/stride_x - output.num_samples() == data.num_samples() - output.k() == filters.num_samples() - The point of setup() is to allow this object to gather information about all the tensor sizes and filter layouts involved in the computation. In particular, the reason the tensors are input into setup() is just to observe their sizes. 
setup() doesn't do anything with the contents of the tensors, or store any kind of references to the data or filter tensors. !*/ private: #ifdef DLIB_USE_CUDA cuda::tensor_conv impl; #else cpu::tensor_conv impl; #endif }; // ---------------------------------------------------------------------------------------- class pooling { /*! WHAT THIS OBJECT REPRESENTS The pooling object is a tool for performing spatial pooling over a tensor. It can be configured to do either max or average pooling. !*/ public: pooling(const pooling&) = delete; pooling& operator=(const pooling&) = delete; pooling ( ) = default; void clear( ) { impl.clear(); } void setup_max_pooling( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x ) { impl.setup_max_pooling(window_height, window_width, stride_y, stride_x, padding_y, padding_x); } /*! requires - window_height > 0 - window_width > 0 - stride_y > 0 - stride_x > 0 - 0 <= padding_y < window_height - 0 <= padding_x < window_width ensures - When you call operator() it will do max pooling with the given parameters. !*/ void setup_avg_pooling( int window_height, int window_width, int stride_y, int stride_x, int padding_y, int padding_x ) { impl.setup_avg_pooling(window_height, window_width, stride_y, stride_x, padding_y, padding_x); } /*! requires - window_height > 0 - window_width > 0 - stride_y > 0 - stride_x > 0 - 0 <= padding_y < window_height - 0 <= padding_x < window_width ensures - When you call operator() it will do average pooling with the given parameters. !*/ bool does_max_pooling( ) const { return impl.does_max_pooling(); } void operator() ( resizable_tensor& dest, const tensor& src ) { impl(dest, src); } /*! requires - is_same_object(dest,src) == false - either setup_max_pooling() or setup_avg_pooling() has been called. 
- window_width <= src.nc() + 2*padding_x - window_height <= src.nr() + 2*padding_y ensures - #dest.num_samples() == src.num_samples() - #dest.k() == src.k() - #dest.nr() == 1 + (src.nr() + 2*padding_y - window_height)/stride_y - #dest.nc() == 1 + (src.nc() + 2*padding_x - window_width)/stride_x - WINDOW == centered_rect(x*stride_x + window_width/2 - padding_x, y*stride_y + window_height/2 - padding_y, window_width, window_height) - for all valid s, k, r, and c: - if (does_max_pooling()) then - image_plane(#dest,s,k)(r,c) == max(subm_clipped(image_plane(src,s,k),WINDOW(c,r))) - else - image_plane(#dest,s,k)(r,c) == mean(subm_clipped(image_plane(src,s,k),WINDOW(c,r))) !*/ void get_gradient( const tensor& gradient_input, const tensor& dest, const tensor& src, tensor& grad ) { impl.get_gradient(gradient_input, dest, src, grad); } /*! requires - have_same_dimensions(gradient_input,dest) == true - have_same_dimensions(src,grad) == true - dest contains the result of calling (*this)(dest,src) - is_same_object(grad,gradient_input) == false - is_same_object(grad,dest) == false - is_same_object(grad,src) == false ensures - Recalling that dest is the output of (*this)(dest,src), let f(src) == dot(gradient_input,dest) - Then this function computes the gradient of f() with respect to src and adds it to grad. !*/ private: #ifdef DLIB_USE_CUDA cuda::pooling impl; #else cpu::pooling impl; #endif }; // ---------------------------------------------------------------------------------------- void softmax( tensor& dest, const tensor& src, operation_mode mode = operation_mode::CHANNEL_WISE ); /*! requires - have_same_dimensions(dest, src) == true - mode == CHANNEL_WISE || mode == PLANE_WISE ensures - Note that the softmax function is a vector valued function: s(x) == exp(x)/sum(exp(x)) - Computes the softmax function on src and writes the results to dest. - If mode == CHANNEL_WISE: The softmax is computed per spatial location across the different channels at each location. 
That is, softmax() outputs a new tensor, #dest, where each of the spatial locations in dest (i.e. image idx, row idx, and column idx) contains the output of s() evaluated over the channel values at each location. - If mode == PLANE_WISE: The softmax is computed across entire planes (nr x nc) of the input tensor. This is useful for operations in Large Language Models (LLMs) and other applications requiring 2D tensor processing. - This function supports in-place operation, i.e. having is_same_object(dest, src)==true !*/ void softmax_gradient( tensor& grad, const tensor& dest, const tensor& gradient_input, operation_mode mode = operation_mode::CHANNEL_WISE ); /*! requires - have_same_dimensions(dest,gradient_input) == true - have_same_dimensions(dest,grad) == true - mode == CHANNEL_WISE || mode == PLANE_WISE ensures - We interpret dest as the output of softmax(dest,SRC,mode) for some SRC tensor. Then let f(SRC) == dot(gradient_input,dest). Then this function computes the gradient of f() with respect to SRC and stores it to grad. Moreover, if is_same_object(grad,gradient_input)==true then the output is assigned to grad, replacing its previous contents. Otherwise the output is added to grad. - The gradient computation takes into account the specified mode: - If mode == CHANNEL_WISE: The gradient is computed per spatial location across channels. - If mode == PLANE_WISE: The gradient is computed across entire planes of the tensor. - This function supports in-place operation, i.e. having is_same_object(grad, gradient_input)==true !*/ // ---------------------------------------------------------------------------------------- void softmax_all ( tensor& dest, const tensor& src ); /*! requires - have_same_dimensions(dest, src) == true ensures - Note that the softmax function is a vector valued function: s(x) == exp(x)/sum(exp(x)) - Computes the softmax function on src and writes the results to dest. The softmax is computed over the entire tensor with one invocation of s(). 
So unlike softmax() which computes many s() evaluations, one for each spatial location, softmax_all() calls s() once for the entire tensor. - This function supports in-place operation, i.e. having is_same_object(dest, src)==true !*/ void softmax_all_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); /*! requires - have_same_dimensions(dest,gradient_input) == true - have_same_dimensions(dest,grad) == true - is_same_object(grad, dest)==false ensures - We interpret dest as the output of softmax_all(dest,SRC) for some SRC tensor. Then let f(SRC) == dot(gradient_input,dest) Then this function computes the gradient of f() with respect to SRC and assigns it to grad. - This function supports in-place operation, i.e. having is_same_object(grad, gradient_input)==true !*/ // ---------------------------------------------------------------------------------------- void sigmoid ( tensor& dest, const tensor& src ); /*! requires - have_same_dimensions(dest, src) == true ensures - for all valid i: - #dest.host()[i] == 1/(1+std::exp(-src.host()[i])) - This function supports in-place operation, i.e. having is_same_object(dest, src)==true !*/ void sigmoid_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); /*! requires - have_same_dimensions(dest,gradient_input) == true - have_same_dimensions(dest,grad) == true ensures - Recalling that dest is the output of sigmoid(dest,SRC) for some SRC tensor, let f(SRC) == dot(gradient_input,dest). Then this function computes the gradient of f() with respect to SRC and stores it to grad. Moreover, if is_same_object(grad,gradient_input)==true then the output is assigned to grad, replacing its previous contents. Otherwise the output is added to grad. - This function supports in-place operation, i.e. having is_same_object(grad, gradient_input)==true !*/ // ---------------------------------------------------------------------------------------- void mish ( tensor& dest, const tensor& src ); /*! 
requires - have_same_dimensions(dest, src) == true ensures - for all valid i: - #dest.host()[i] == src.host()[i]*std::tanh(std::log(1+std::exp(src.host()[i]))) - This function supports in-place operation, i.e. having is_same_object(dest, src)==true !*/ void mish_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); /*! requires - have_same_dimensions(dest,gradient_input) == true - have_same_dimensions(dest,grad) == true ensures - This function computes the gradient of f() with respect to SRC and stores it to grad. Moreover, if is_same_object(grad,gradient_input)==true then the output is assigned to grad, replacing its previous contents. Otherwise the output is added to grad. - This function supports in-place operation, i.e. having is_same_object(grad, gradient_input)==true !*/ // ---------------------------------------------------------------------------------------- void relu ( tensor& dest, const tensor& src ); /*! requires - have_same_dimensions(dest, src) == true ensures - for all valid i: - #dest.host()[i] == std::max(0,src.host()[i]) - This function supports in-place operation, i.e. having is_same_object(dest, src)==true !*/ void relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input ); /*! requires - have_same_dimensions(dest,gradient_input) == true - have_same_dimensions(dest,grad) == true ensures - Recalling that dest is the output of relu(dest,SRC) for some SRC tensor, let f(SRC) == dot(gradient_input,dest). Then this function computes the gradient of f() with respect to SRC and stores it to grad. Moreover, if is_same_object(grad,gradient_input)==true then the output is assigned to grad, replacing its previous contents. Otherwise the output is added to grad. - This function supports in-place operation, i.e. 
having is_same_object(grad, gradient_input)==true !*/ // ---------------------------------------------------------------------------------------- void prelu ( tensor& dest, const tensor& src, const tensor& param ); /*! requires - have_same_dimensions(dest, src) == true - param.size() == 1 ensures - for all valid i: - if (src.host()[i] > 0) then - #dest.host()[i] == src.host()[i] - else - #dest.host()[i] == src.host()[i] * param.host()[0] - This function supports in-place operation, i.e. having is_same_object(dest, src)==true !*/ void prelu_gradient ( tensor& grad, const tensor& src, const tensor& gradient_input, const tensor& param, tensor& params_grad ); /*! requires - have_same_dimensions(grad,src) == true - have_same_dimensions(grad,gradient_input) == true - param.size() == 1 - params_grad.size() == 1 - is_same_object(grad, gradient_input) == false ensures - Recalling that dest is the output of prelu(dest,src,param) let f(src,param) == dot(gradient_input,dest) - Then this function computes the gradient of f() with respect to src and param. It assigns the gradient with respect to param to #params_grad and adds the gradient with respect to src to #grad. !*/ // ---------------------------------------------------------------------------------------- void leaky_relu ( tensor& dest, const tensor& src, const float alpha ); /*! requires - have_same_dimensions(dest, src) == true ensures - for all valid i: - if (src.host()[i] > 0) then - #dest.host()[i] == src.host()[i] - else - #dest.host()[i] == src.host()[i] * alpha !*/ void leaky_relu_gradient ( tensor& grad, const tensor& dest, const tensor& gradient_input, const float alpha ); /*! requires - have_same_dimensions(dest,gradient_input) == true - have_same_dimensions(dest,grad) == true ensures - Recalling that dest is the output of leaky_relu(dest,SRC) for some SRC tensor, let f(SRC) == dot(gradient_input,dest). Then this function computes the gradient of f() with respect to SRC and stores it to grad. 
Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void tanh (
        tensor& dest,
        const tensor& src
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
        ensures
            - for all valid i:
                - #dest.host()[i] == std::tanh(src.host()[i])
            - This function supports in-place operation, i.e. having
              is_same_object(dest, src)==true
    !*/

    void tanh_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input
    );
    /*!
        requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
        ensures
            - Recalling that dest is the output of tanh(dest,SRC) for some SRC tensor,
              let f(SRC) == dot(gradient_input,dest).  Then this function computes the
              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void clipped_relu (
        tensor& dest,
        const tensor& src,
        const float ceiling
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
        ensures
            - for all valid i:
                - #dest.host()[i] == std::min(std::max(src.host()[i], 0), ceiling)
                  (i.e. for ceiling >= 0, src is clamped to the range [0, ceiling])
            - This function supports in-place operation, i.e. having
              is_same_object(dest, src)==true
    !*/

    void clipped_relu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float ceiling
    );
    /*!
requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
        ensures
            - Recalling that dest is the output of clipped_relu(dest,SRC,ceiling) for
              some SRC tensor, let f(SRC) == dot(gradient_input,dest).  Then this
              function computes the gradient of f() with respect to SRC and stores it to
              grad.  Moreover, if is_same_object(grad,gradient_input)==true then the
              output is assigned to grad, replacing its previous contents.  Otherwise
              the output is added to grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void elu (
        tensor& dest,
        const tensor& src,
        const float alpha
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
        ensures
            - for all valid i:
                - if (src.host()[i] > 0) then
                    - #dest.host()[i] == src.host()[i]
                - else
                    - #dest.host()[i] == alpha * (std::exp(src.host()[i]) - 1)
            - This function supports in-place operation, i.e. having
              is_same_object(dest, src)==true
    !*/

    void elu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float alpha
    );
    /*!
        requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
        ensures
            - Recalling that dest is the output of elu(dest,SRC) for some SRC tensor,
              let f(SRC) == dot(gradient_input,dest).  Then this function computes the
              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void gelu (
        tensor& dest,
        const tensor& src
    );
    /*!
requires
            - have_same_dimensions(dest, src) == true
        ensures
            - for all valid i:
                - #dest.host()[i] == src.host()[i]/2 * (1 + erf(src.host()[i]/sqrt(2)))
            - This function supports in-place operation, i.e. having
              is_same_object(dest, src)==true
    !*/

    void gelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input
    );
    /*!
        requires
            - have_same_dimensions(src,gradient_input) == true
            - have_same_dimensions(src,grad) == true
        ensures
            - Recalling that dest is the output of gelu(dest,src), let f(src) ==
              dot(gradient_input,dest).  Then this function computes the gradient of
              f() with respect to src and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - Note that, unlike relu_gradient(), this function takes the forward input
              src rather than the forward output dest.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void smelu (
        tensor& dest,
        const tensor& src,
        const float beta
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
            - beta > 0
        ensures
            - for all valid i:
                - if (src.host()[i] > beta) then
                    - #dest.host()[i] == src.host()[i]
                - else if (src.host()[i] < -beta) then
                    - #dest.host()[i] == 0
                - else
                    - #dest.host()[i] == std::pow(src.host()[i] + beta, 2) / (4 * beta)
              (the quadratic piece joins the other two continuously at -beta and beta)
    !*/

    void smelu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float beta
    );
    /*!
        requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
            - beta > 0
        ensures
            - Recalling that dest is the output of smelu(dest,SRC) for some SRC tensor,
              let f(SRC) == dot(gradient_input,dest).  Then this function computes the
              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
- This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void silu (
        tensor& dest,
        const tensor& src
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
        ensures
            - for all valid i:
                - #dest.host()[i] == src.host()[i] * sigmoid(src.host()[i])
            - This function supports in-place operation, i.e. having
              is_same_object(dest, src)==true
    !*/

    void silu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input
    );
    /*!
        requires
            - have_same_dimensions(src,gradient_input) == true
            - have_same_dimensions(src,grad) == true
        ensures
            - Recalling that dest is the output of silu(dest,src), let f(src) ==
              dot(gradient_input,dest).  Then this function computes the gradient of
              f() with respect to src and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void resize_bilinear (
        tensor& dest,
        long dest_row_stride,
        long dest_channel_stride,
        const tensor& src,
        long src_row_stride,
        long src_channel_stride
    );
    /*!
        requires
            - is_same_object(dest, src)==false
            - dest.num_samples() == src.num_samples()
            - dest.k() == src.k()
        ensures
            - for all valid i,k:  image_plane(dest,i,k) is a copy of image_plane(src,i,k)
              that has been bilinearly interpolated to fit into the shape of
              image_plane(dest,i,k).
            - Instead of supposing the row stride and channel stride in the tensors are
              given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use
              the provided stride values to transition from one row and channel to the
              next.  This is useful in combination with alias_tensor objects since it
              allows you to operate on subwindows in an image.
!*/ void resize_bilinear_gradient ( tensor& grad, long grad_row_stride, long grad_channel_stride, const tensor& gradient_input, long gradient_input_row_stride, long gradient_input_channel_stride ); /*! requires - is_same_object(grad, gradient_input)==false - gradient_input.num_samples() == grad.num_samples() - gradient_input.k() == grad.k() ensures - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC tensor, let f(SRC) == dot(gradient_input,DEST). Then this function computes the gradient of f() with respect to SRC and adds it to grad. It should be noted that we don't need to know the contents of DEST to compute this gradient. All that matters is that gradient_input have the same dimensions as DEST. - Instead of supposing the row stride and channel stride in the tensors is given by tensor::nc() and tensor::nr()*tensor::nc() respectively, we use the provided stride values to transition from one row and channel to the next. This is useful in combination with alias_tensor objects since it allows you to operate on subwindows in an image. !*/ inline void resize_bilinear ( tensor& dest, const tensor& src ) { resize_bilinear(dest, dest.nc(), dest.nr()*dest.nc(), src, src.nc(), src.nr()*src.nc()); } /*! requires - is_same_object(dest, src)==false - dest.num_samples() == src.num_samples() - dest.k() == src.k() ensures - for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k) that has been bilinearly interpolated to fit into the shape of image_plane(dest,i,k). !*/ inline void resize_bilinear_gradient ( tensor& grad, const tensor& gradient_input ) { resize_bilinear_gradient(grad, grad.nc(), grad.nr()*grad.nc(), gradient_input, gradient_input.nc(), gradient_input.nr()*gradient_input.nc()); } /*! 
requires
            - is_same_object(grad, gradient_input)==false
            - gradient_input.num_samples() == grad.num_samples()
            - gradient_input.k() == grad.k()
        ensures
            - Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC
              tensor, let f(SRC) == dot(gradient_input,DEST).  Then this function
              computes the gradient of f() with respect to SRC and adds it to grad.  It
              should be noted that we don't need to know the contents of DEST to compute
              this gradient.  All that matters is that gradient_input have the same
              dimensions as DEST.
    !*/

    // ----------------------------------------------------------------------------------------

    void reorg (
        bool add_to,
        tensor& dest,
        const int row_stride,
        const int col_stride,
        const tensor& src
    );
    /*!
        requires
            - !is_same_object(dest, src)
            - src.nr() % row_stride == 0
            - src.nc() % col_stride == 0
            - dest.num_samples() == src.num_samples()
            - dest.k() == src.k() * row_stride * col_stride
            - dest.nr() == src.nr() / row_stride
            - dest.nc() == src.nc() / col_stride
        ensures
            - Reorganizes the spatial resolution of src into channel information in
              dest, effectively shifting spatial data into the channel dimension based
              on the specified strides.  Per the formula below, the offset inside each
              row_stride x col_stride block of src selects one of row_stride*col_stride
              consecutive groups of src.k() channels in dest.
            - If add_to is false:
                - Each element in dest is set to the corresponding reorganized value
                  from src.
            - If add_to is true:
                - Each element in dest is incremented by the corresponding reorganized
                  value from src.
            - Specifically, for all n, k, r, c in dest:
                - If add_to is false:
                    dest.host[tensor_index(dest, n, k, r, c)] =
                        src.host[tensor_index(src, n, k % src.k(),
                                              r * row_stride + (k / src.k()) / col_stride,
                                              c * col_stride + (k / src.k()) % col_stride)];
                - If add_to is true:
                    dest.host[tensor_index(dest, n, k, r, c)] +=
                        src.host[tensor_index(src, n, k % src.k(),
                                              r * row_stride + (k / src.k()) / col_stride,
                                              c * col_stride + (k / src.k()) % col_stride)];
    !*/

    void reorg_gradient (
        bool add_to,
        tensor& grad,
        const int row_stride,
        const int col_stride,
        const tensor& gradient_input
    );
    /*!
requires
            - !is_same_object(grad, gradient_input)
            - gradient_input.nr() % row_stride == 0
            - gradient_input.nc() % col_stride == 0
            - grad.num_samples() == gradient_input.num_samples()
            - grad.k() == gradient_input.k() / row_stride / col_stride
            - grad.nr() == gradient_input.nr() * row_stride
            - grad.nc() == gradient_input.nc() * col_stride
        ensures
            - Computes the gradient of the function f(SRC) = DEST, where DEST is the
              result of reorg(false, DEST, row_stride, col_stride, SRC).
            - If add_to is false:
                - Each element in grad is set to the corresponding gradient value.
            - If add_to is true:
                - Each element in grad is incremented by the corresponding gradient value.
            - Specifically, for all n, k, r, c in grad:
                - If add_to is false:
                    grad.host[tensor_index(grad, n, k, r, c)] =
                        gradient_input.host[tensor_index(gradient_input, n,
                            (k*row_stride*col_stride) + (r%row_stride)*col_stride + c%col_stride,
                            r/row_stride, c/col_stride)];
                - If add_to is true:
                    grad.host[tensor_index(grad, n, k, r, c)] +=
                        gradient_input.host[tensor_index(gradient_input, n,
                            (k*row_stride*col_stride) + (r%row_stride)*col_stride + c%col_stride,
                            r/row_stride, c/col_stride)];
            - This function effectively reverses the reorg operation, distributing
              gradients from the channel dimension of gradient_input to the spatial
              dimensions of grad.
            - NOTE(review): the channel index above puts the spatial block offset in the
              low-order part of the channel (k*row_stride*col_stride + offset), while
              reorg()'s documented mapping puts it in the high-order part
              (offset*src.k() + k).  The two formulas only agree when grad.k() == 1 or
              row_stride*col_stride == 1; confirm which one the implementation follows.
    !*/

    // ----------------------------------------------------------------------------------------

    void embeddings(
        resizable_tensor& dest,
        const tensor& src,
        const tensor& embs
    );
    /*!
        requires
            - src.nr() > 0
            - embs.num_samples() > 0
            - embs.k() > 0
            - embs.nr() == 1
            - embs.nc() == 1
            - dest.num_samples() == src.num_samples()
            - dest.k() == src.k()
            - dest.nr() == src.nr()
            - dest.nc() == embs.k()
        ensures
            - Maps each token index stored in `src` to its embedding vector, which is the
              corresponding sample of the embedding table `embs` (embs.k() values each).
            - The resulting embeddings are stored in the `dest` tensor.
- For all valid s (0 <= s < dest.num_samples()), k (0 <= k < dest.k()), r (0 <= r < dest.nr()), c (0 <= c < dest.nc()): - Let token_idx = static_cast(src(s,k,r,0)) - If token_idx < embs.num_samples(): - #dest(s,k,r,c) = embs(token_idx, c, 0, 0) - Else: - #dest(s,k,r,c) = 0 - The function iterates over all elements of src and populates dest accordingly. - If a token index in src is out of range (>= embs.num_samples()), the corresponding embedding in dest is filled with 0's. */ void embeddings_gradient( const tensor& prev, const tensor& gradient_input, tensor& grads, const tensor& freqs, float learning_rate, bool scale ); /*! requires - prev.nr() > 0 - gradient_input.num_samples() == prev.num_samples() - gradient_input.k() == prev.k() - gradient_input.nr() == prev.nr() - gradient_input.nc() == grads.k() - grads.num_samples() > 0 - grads.k() > 0 - grads.nr() == 1 - grads.nc() == 1 - freqs.num_samples() == grads.num_samples() - freqs.k() == 1 - freqs.nr() == 1 - freqs.nc() == 1 ensures - Updates the `grads` tensor based on the gradients in `gradient_input`. - For each sample s, channel k, and row r in prev: - Retrieves the token index from prev[s,k,r,0] - If the token index is valid (< grads.num_samples()): - If scale is true: - Computes a frequency scale factor based on freqs[token_idx] - The scale factor is min(0.15, max(1.0 / freqs[token_idx], 1.0)) - For each column c in gradient_input: - Updates grads[token_idx, c] -= gradient_input[s,k,r,c] * learning_rate * freq_scale - The updates to grads are performed atomically to handle concurrent updates to the same embedding. - The function is thread-safe and processes samples in parallel. */ // ---------------------------------------------------------------------------------------- class multi_device_tensor_averager { /*! WHAT THIS OBJECT REPRESENTS This object is a tool for very quickly averaging a bunch of tensors together. 
!*/ public: multi_device_tensor_averager(const multi_device_tensor_averager&) = delete; multi_device_tensor_averager& operator=(const multi_device_tensor_averager&) = delete; multi_device_tensor_averager() = default; void set( std::vector items ) /*! requires - All the tensors in items are the same size ensures - When you call average() we will average the tensors in items. - It's important that the tensors already be allocated to their devices before you call set(). This is because set() will setup the types of between device transfers now and use them when you call average(). !*/ { using namespace ::dlib::cuda; accessible_groups.clear(); epa.clear(); if (items.size() < 1) return; scale = 1.f/items.size(); // split item into groups of accessible devices std::vector group, unused; while(items.size() > 0) { group.push_back(items[0]); for(size_t i = 1; i < items.size(); ++i) { if (can_access_peer(*items[0], *items[i])) group.push_back(items[i]); else unused.push_back(items[i]); } accessible_groups.push_back(group); unused.swap(items); unused.clear(); group.clear(); } for (auto&& g : accessible_groups) { for (size_t i = 1; i < g.size(); ++i) { epa.emplace_back(new enable_peer_access(*g[0], *g[i])); } } } size_t num_device_groups( ) const { return accessible_groups.size(); } /*! ensures - The devices given to set() are grouped together when they can directly access each other using GPUDirect. This function returns the number of such groups. For example, if all devices can directly access each other then the number of groups is 1. !*/ void average() /*! requires - All the devices have stopped writing to the tensors given to set(). So you should probably call cudaDeviceSynchronize() on each of the relevant devices before calling average(). ensures - Computes the average of all the tensors given to set() and then sets them all equal to the average. 
!*/ { using namespace ::dlib::cuda; // First we average things within each group for (auto&& g : accessible_groups) { raii_set_device set_dev(*g[0]); if (g.size() == 1) tt::affine_transform(*g[0], *g[0], scale); else tt::affine_transform(*g[0], *g[0], *g[1], scale, scale); for (size_t i = 2; i < g.size(); ++i) tt::affine_transform(*g[0], *g[0], *g[i], 1, scale); } if (accessible_groups.size() > 1) { tensor& total_avg = *accessible_groups[0][0]; raii_set_device set_dev(total_avg); accum_buffer.copy_size(total_avg); // now we need to average things across groups for (size_t i = 1; i < accessible_groups.size(); ++i) { memcpy(accum_buffer, *accessible_groups[i][0]); tt::add(total_avg, total_avg, accum_buffer); } // Now total_avg has the final average in it. So we need to send // copies of it back to each of the groups. for (size_t i = 1; i < accessible_groups.size(); ++i) { memcpy(*accessible_groups[i][0], total_avg); } } // Now propagate averages back out to each element using point to point // communication inside a group. for (auto&& g : accessible_groups) { raii_set_device set_dev(*g[0]); for (size_t i = 1; i < g.size(); ++i) memcpy(*g[i], *g[0]); } } private: std::vector> epa; std::vector> accessible_groups; float scale; resizable_tensor accum_buffer; }; // ---------------------------------------------------------------------------------------- void copy_tensor( bool add_to, tensor& dest, size_t dest_k_offset, const tensor& src, size_t src_k_offset, size_t count_k ); /*! requires - dest.nc() == src.nc() - dest.nr() == src.nr() - dest.num_samples() == src.num_samples() - dest.k() - dest_k_offset >= count_k - src.k() - src_k_offset >= count_k - is_same_object(dest,src) == false - The memory areas of src and dest do not overlap. ensures - if (add_to) then - performs: dest[i, k + dest_k_offset, r, c] += src[i, k + src_k_offset, r, c], where k in [0..count_k] i.e., adds content of each sample from src in to corresponding place of sample at dest. 
- else
                - performs: dest[i, k + dest_k_offset, r, c] = src[i, k + src_k_offset, r, c],
                  where 0 <= k < count_k, i.e., copies content of each sample from src into
                  the corresponding place of sample at dest.
    !*/

    // ----------------------------------------------------------------------------------------

    void copy_tensor(
            bool add_to,
            tensor& dest,
            size_t dk, size_t dnr, size_t dnc,
            const tensor& src,
            size_t sk, size_t snr, size_t snc,
            size_t k, size_t nr, size_t nc
    );
    /*!
        requires
            - dest.num_samples() == src.num_samples()
            - dest.k() - dk >= k
            - dest.nr() - dnr >= nr
            - dest.nc() - dnc >= nc
            - src.k() - sk >= k
            - src.nr() - snr >= nr
            - src.nc() - snc >= nc
            - is_same_object(dest,src) == false
            - The memory areas of src and dest do not overlap.
        ensures
            - if (add_to) then
                - performs: dest[i, j + dk, r + dnr, c + dnc] += src[i, j + sk, r + snr, c + snc],
                  where 0 <= j < k, 0 <= r < nr and 0 <= c < nc, i.e., adds content of
                  each sample from src into the corresponding place of sample at dest.
            - else
                - performs: dest[i, j + dk, r + dnr, c + dnc] = src[i, j + sk, r + snr, c + snc],
                  where 0 <= j < k, 0 <= r < nr and 0 <= c < nc, i.e., copies content of
                  each sample from src into the corresponding place of sample at dest.
    !*/

    // ----------------------------------------------------------------------------------------

    void transpose(
        bool add_to,
        tensor& dest,
        const tensor& src
    );
    /*!
        requires
            - is_same_object(dest, src) == false
            - dest.num_samples() == src.num_samples()
            - dest.k() == src.k()
            - dest.nr() == src.nc()
            - dest.nc() == src.nr()
        ensures
            - Performs a transpose operation on the nr() x nc() matrices within src.
            - If (add_to) is false:
                - The result is stored in dest, overwriting its previous contents.
                - For all valid n, k, r, c:
                    - #dest(n,k,c,r) == src(n,k,r,c)
            - If (add_to) is true:
                - The result is added to the existing contents of dest.
- For all valid n, k, r, c:
                    - #dest(n,k,c,r) == dest(n,k,c,r) + src(n,k,r,c)
    !*/

    // ----------------------------------------------------------------------------------------

    // ACT (Adaptive Computation Time) operations

    void compute_act_halt_probabilities(
        resizable_tensor& halt_probs,
        resizable_tensor& logits,
        const tensor& input_data,
        const tensor& halt_params,
        long batch_size,
        long seq_len,
        long feature_dim
    );
    /*!
        requires
            - halt_params.size() == feature_dim + 1 (weights + bias)
            - input_data.num_samples() == batch_size
            - input_data.k() == num_channels
              where feature_dim = num_channels * d_model
            - input_data.nr() == seq_len
            - input_data.nc() == d_model
            (num_channels and d_model are not parameters of this function; they are
            just names for input_data.k() and input_data.nc(), so feature_dim must equal
            input_data.k()*input_data.nc().)
        ensures
            - Computes halting probabilities for Adaptive Computation Time:
                - halt_probs contains sigmoid(W_halt^T * input + b_halt) for each position
                  (presumably W_halt is the first feature_dim values of halt_params and
                  b_halt the last one -- TODO confirm against the implementation)
                - logits contains the pre-sigmoid values
            - batch_size: number of samples in the batch
            - seq_len: sequence length (number of positions to process)
            - feature_dim: total feature dimension (num_channels × d_model)
    !*/

    void update_act_state(
        resizable_tensor& output,
        const tensor& input_data,
        const tensor& halt_probs,
        resizable_tensor& cumulative_halting,
        resizable_tensor& remainders,
        resizable_tensor& n_steps,
        resizable_tensor& effective_weights,
        long batch_size,
        long seq_len,
        long d_model,
        long num_channels,
        float halt_threshold,
        long current_step
    );
    /*!
requires
            - 0 < halt_threshold <= 1.0
            - current_step >= 0
            - input_data.num_samples() == batch_size
            - input_data.k() == num_channels
            - input_data.nr() == seq_len
            - input_data.nc() == d_model
            - output has the same dimensions as input_data
            - halt_probs.size() == batch_size * seq_len
            - cumulative_halting.size() == remainders.size() == n_steps.size() ==
              effective_weights.size() == batch_size * seq_len
        ensures
            - Core ACT update step that accumulates weighted outputs:
                - Updates ACT state for all positions
                - Accumulates weighted outputs: output += α_t^n * input_data
                  (α_t^n presumably denotes the weight given at this step to position t
                  -- TODO confirm against the implementation)
                - Updates cumulative_halting, remainders, n_steps, and effective_weights
            - The state tensors presumably hold one value per (sample, position) pair,
              which is why their size is batch_size * seq_len.
            - batch_size: number of samples in the batch
            - seq_len: sequence length (number of positions to process)
            - d_model: model dimension per channel
            - num_channels: number of feature channels
            - halt_threshold: halting threshold (typically 0.99)
            - current_step: current computation step index (0-based)
    !*/

    void finalize_act_output(
        resizable_tensor& output,
        const tensor& input_data,
        const tensor& remainders,
        resizable_tensor& effective_weights,
        long batch_size,
        long seq_len,
        long d_model,
        long num_channels
    );
    /*!
requires
            - input_data.num_samples() == batch_size
            - input_data.k() == num_channels
            - input_data.nr() == seq_len
            - input_data.nc() == d_model
            - output has the same dimensions as input_data
            - remainders.size() == effective_weights.size() == batch_size * seq_len
        ensures
            - Finalizes ACT output by adding remainder contributions:
                - Adds final remainder contributions: output += ρ_t * input_data
                - Updates effective_weights with remainder values
                - Applied only to positions with significant remainder (> 1e-6)
            - batch_size: number of samples in the batch
            - seq_len: sequence length (number of positions to process)
            - d_model: model dimension per channel
            - num_channels: number of feature channels
    !*/

    void apply_act_depth_scaling(
        tensor& gradients,
        const tensor& n_steps,
        long batch_size,
        long seq_len,
        long d_model,
        long num_channels,
        float max_steps,
        float scale_factor
    );
    /*!
        requires
            - scale_factor >= 0
            - max_steps > 0
            - gradients.num_samples() == batch_size
            - gradients.k() == num_channels
            - gradients.nr() == seq_len
            - gradients.nc() == d_model
            - n_steps.size() == batch_size * seq_len
        ensures
            - Applies gradient scaling based on computation depth:
                - Applies depth-dependent gradient scaling
                - scale = 1 + scale_factor * (n_steps[pos] / max_steps)
                  (presumably every gradient value at position pos is multiplied in
                  place by scale -- TODO confirm against the implementation)
            - batch_size: number of samples in the batch
            - seq_len: sequence length (number of positions to process)
            - d_model: model dimension per channel
            - num_channels: number of feature channels
            - max_steps: maximum allowed computation steps
            - scale_factor: scaling strength (0 = no scaling)
    !*/

// ----------------------------------------------------------------------------------------

}}

#ifdef NO_MAKEFILE
#include "tensor_tools.cpp"
#endif

#endif // DLIB_TeNSOR_TOOLS_H_


================================================
FILE: dlib/data_io/arc_agi.h
================================================
// Copyright (C) 2025 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARC_AGI_H_ #define DLIB_ARC_AGI_H_ #include "arc_agi_abstract.h" #include #include #include #include #include #include #include #include #include "../matrix.h" #include "../dir_nav.h" #include "../serialize.h" namespace dlib { // ---------------------------------------------------------------------------------------- // Type aliases and constants // ---------------------------------------------------------------------------------------- /*! Type aliases for ARC-AGI data structures. Grids are represented as matrices of unsigned char values (0-9), and token sequences are column vectors of long. !*/ using arc_grid_t = matrix; using arc_token_sequence_t = matrix; /*! Maximum sequence length for LLM-style training. This constant defines the upper bound for token sequences that can be processed by the model. !*/ constexpr long ARC_MAX_SEQUENCE_LENGTH = 4096; // ---------------------------------------------------------------------------------------- // Token vocabulary // ---------------------------------------------------------------------------------------- /*! Token vocabulary for the Hierarchical Reasoning Model. The vocabulary includes: - COLOR_0 to COLOR_9: Grid cell colors (10 values) - TOKEN_SEP_IO: Separator between input and output grids - TOKEN_SEP_PAIR: Separator between demonstration pairs - TOKEN_QUERY_START: Marks the beginning of a test query - TOKEN_GEN_START: Marks the beginning of generation phase - TOKEN_END_OF_OUTPUT: Marks the end of generated output - TOKEN_PADDING: Padding token for variable-length sequences - TOKEN_ROW_END: Marks the end of a grid row (for dimension encoding) !*/ enum arc_token_id : long { COLOR_0 = 0, COLOR_1 = 1, COLOR_2 = 2, COLOR_3 = 3, COLOR_4 = 4, COLOR_5 = 5, COLOR_6 = 6, COLOR_7 = 7, COLOR_8 = 8, COLOR_9 = 9, TOKEN_SEP_IO = 10, TOKEN_SEP_PAIR = 11, TOKEN_QUERY_START = 12, TOKEN_GEN_START = 13, TOKEN_END_OF_OUTPUT = 14, TOKEN_PADDING = 15, TOKEN_ROW_END = 16 }; /*! Vocabulary size constants for the token set. 
!*/ constexpr long ARC_VOCAB_SIZE_COLORS = 10; constexpr long ARC_VOCAB_SIZE_TOTAL = 17; // ---------------------------------------------------------------------------------------- // ARC-AGI task data structures // ---------------------------------------------------------------------------------------- /*! Represents a single input-output pair in an ARC-AGI task. Each pair consists of an input grid and its corresponding output grid, along with their dimensions. !*/ struct arc_task_pair { arc_grid_t input; arc_grid_t output; long input_rows; long input_cols; long output_rows; long output_cols; friend void serialize(const arc_task_pair& item, std::ostream& out) { dlib::serialize(item.input, out); dlib::serialize(item.output, out); dlib::serialize(item.input_rows, out); dlib::serialize(item.input_cols, out); dlib::serialize(item.output_rows, out); dlib::serialize(item.output_cols, out); } friend void deserialize(arc_task_pair& item, std::istream& in) { dlib::deserialize(item.input, in); dlib::deserialize(item.output, in); dlib::deserialize(item.input_rows, in); dlib::deserialize(item.input_cols, in); dlib::deserialize(item.output_rows, in); dlib::deserialize(item.output_cols, in); } }; /*! Represents a complete ARC-AGI task. 
Each task contains: - A unique task identifier - A set of training demonstration pairs - A set of test pairs (where outputs are to be predicted) !*/ struct arc_task { std::string task_id; std::vector train_pairs; std::vector test_pairs; friend void serialize(const arc_task& item, std::ostream& out) { dlib::serialize(item.task_id, out); dlib::serialize(item.train_pairs, out); dlib::serialize(item.test_pairs, out); } friend void deserialize(arc_task& item, std::istream& in) { dlib::deserialize(item.task_id, in); dlib::deserialize(item.train_pairs, in); dlib::deserialize(item.test_pairs, in); } }; // ---------------------------------------------------------------------------------------- // Internal JSON parsing utilities // ---------------------------------------------------------------------------------------- namespace internal { using raw_arc_grid_t = std::vector>; // ------------------------------------------------------------------------------------ inline std::string read_file_to_string(const std::string& path) /*! ensures - Reads the entire contents of a file and returns it as a string - Throws std::runtime_error if the file cannot be opened !*/ { std::ifstream file(path); if (!file.is_open()) throw std::runtime_error("Failed to open file: " + path); std::stringstream buffer; buffer << file.rdbuf(); return buffer.str(); } // ------------------------------------------------------------------------------------ inline std::vector parse_int_array(const std::string& str) /*! 
ensures - Parses a comma-separated string of integers - Returns a vector containing the parsed integers - Whitespace around numbers is automatically stripped !*/ { std::vector result; std::stringstream ss(str); std::string segment; while (std::getline(ss, segment, ',')) { segment.erase(0, segment.find_first_not_of(" \t\n\r")); segment.erase(segment.find_last_not_of(" \t\n\r") + 1); if (!segment.empty()) result.push_back(std::stoi(segment)); } return result; } // ------------------------------------------------------------------------------------ inline raw_arc_grid_t parse_arc_grid(std::string::const_iterator& it, const std::string::const_iterator& end) /*! ensures - Parses a 2D grid from JSON array-of-arrays format - Advances the iterator 'it' past the parsed content - Returns a vector of vectors representing the grid rows - Throws std::runtime_error on malformed input !*/ { raw_arc_grid_t grid; // Locate the opening bracket of the outer array it = std::find(it, end, '['); if (it == end) return grid; ++it; // Skip any leading whitespace while (it != end && std::isspace(*it)) ++it; // Verify we have an array of arrays (second '[') if (it == end || *it != '[') return grid; // Parse each row in the grid while (it != end) { // Skip whitespace between rows while (it != end && std::isspace(*it)) ++it; // Check for end of outer array if (it == end || *it == ']') break; // Expect a '[' at the start of each row if (*it != '[') { ++it; continue; } ++it; // Find the closing ']' for this row auto inner_end = std::find(it, end, ']'); if (inner_end == end) throw std::runtime_error("Missing inner array closing bracket"); // Parse the integers in this row std::string row_str(it, inner_end); auto row = parse_int_array(row_str); if (!row.empty()) grid.push_back(row); it = inner_end; ++it; // Skip trailing whitespace, commas, and newlines while (it != end && (*it == ' ' || *it == ',' || *it == '\n' || *it == '\r' || *it == '\t')) ++it; } // Advance past the closing ']' of the outer 
array if (it != end && *it == ']') ++it; return grid; } // ------------------------------------------------------------------------------------ inline std::string::const_iterator find_key_value_start( const std::string& content, const std::string& key, std::string::const_iterator start_it) /*! ensures - Searches for a JSON key-value pair starting from start_it - Returns an iterator pointing to the first character of the value - Returns content.end() if the key is not found !*/ { std::string search_str = "\"" + key + "\":"; auto pos = std::search(start_it, content.end(), search_str.begin(), search_str.end()); if (pos == content.end()) return content.end(); pos += search_str.length(); while (pos != content.end() && std::isspace(*pos)) ++pos; return pos; } // ------------------------------------------------------------------------------------ inline std::string extract_task_id_from_filename(const std::string& filename) /*! ensures - Extracts the task ID from a filename by removing the file extension - If no extension is found, returns the filename unchanged !*/ { size_t dot_pos = filename.find_last_of('.'); if (dot_pos == std::string::npos) return filename; return filename.substr(0, dot_pos); } } // namespace internal // ---------------------------------------------------------------------------------------- // arc_agi_manager class // ---------------------------------------------------------------------------------------- /*! The arc_agi_manager class provides functionality to: - Load ARC-AGI tasks from JSON files - Manage training and evaluation datasets - Convert grids to token sequences for LLM training - Generate training batches with sliding window context - Serialize and deserialize task data THREAD SAFETY This class is not thread-safe. External synchronization is required if accessing the same instance from multiple threads. TOKENIZATION STRATEGY Grids are tokenized row-by-row with TOKEN_ROW_END markers to preserve dimensional information. 
This allows the model to learn the structure of non-square grids (ranging from 1x1 to 30x30) without explicit dimension encoding. !*/ class arc_agi_manager { private: std::vector training_tasks; std::vector evaluation_tasks; std::map training_task_id_map; std::map evaluation_task_id_map; // ------------------------------------------------------------------------------------ static void append_flat_grid(std::vector& sequence, const arc_grid_t& grid) /*! requires - grid contains valid color values (0-9) ensures - Appends the grid to the sequence in row-major order - Each row is terminated with TOKEN_ROW_END - This encoding preserves grid dimensions for reconstruction !*/ { for (long r = 0; r < grid.nr(); ++r) { for (long c = 0; c < grid.nc(); ++c) sequence.push_back(static_cast(grid(r, c))); // Mark the end of this row to encode dimensional information sequence.push_back(TOKEN_ROW_END); } } // ------------------------------------------------------------------------------------ static arc_grid_t to_dlib_matrix(const internal::raw_arc_grid_t& grid) /*! 
requires - grid is a valid 2D array with consistent row lengths - all values are in the range [0, 9] ensures - Converts a raw vector-of-vectors grid to a dlib matrix - Returns an empty matrix if the input grid is empty throws - DLIB_CASSERT if row lengths are inconsistent - DLIB_CASSERT if pixel values are outside [0, 9] !*/ { if (grid.empty()) return arc_grid_t(0, 0); long rows = static_cast(grid.size()); long cols = static_cast(grid[0].size()); arc_grid_t mat(rows, cols); for (long r = 0; r < rows; ++r) { DLIB_CASSERT(static_cast(grid[r].size()) == cols, "Inconsistent column size in grid"); for (long c = 0; c < cols; ++c) { DLIB_CASSERT(grid[r][c] >= 0 && grid[r][c] <= 9, "Invalid pixel value (must be 0-9)"); mat(r, c) = static_cast(grid[r][c]); } } return mat; } // ------------------------------------------------------------------------------------ arc_task parse_arc_task_from_content(const std::string& content, const std::string& filename) /*! ensures - Parses a complete ARC task from JSON content - Returns an arc_task structure with all training and test pairs - Task ID is extracted from the filename throws - std::runtime_error on malformed JSON or missing required fields !*/ { arc_task task; task.task_id = internal::extract_task_id_from_filename(filename); auto parse_pairs = [&](const std::string& key, std::vector& pairs) { auto it = internal::find_key_value_start(content, key, content.begin()); if (it == content.end() || *it != '[') throw std::runtime_error("'" + key + "' array not found"); ++it; // Iterate through each object in the array while (it != content.end()) { // Skip inter-object whitespace while (it != content.end() && std::isspace(*it)) ++it; // Check if we've reached the end of the array if (it == content.end() || *it == ']') break; // Locate the opening brace of this object if (*it != '{') { ++it; continue; } // Mark boundaries for scoped key searches auto object_start = it; ++it; // Find the matching closing brace int brace_depth = 1; auto 
object_end = it; while (object_end != content.end() && brace_depth > 0) { if (*object_end == '{') ++brace_depth; else if (*object_end == '}') --brace_depth; ++object_end; } if (object_end == content.end()) throw std::runtime_error("Missing object closing bracket"); arc_task_pair pair; // Parse the "input" field within this object's scope auto input_it = internal::find_key_value_start(content, "input", object_start); if (input_it == content.end() || input_it >= object_end) throw std::runtime_error("'input' not found in " + key + " object"); auto raw_input = internal::parse_arc_grid(input_it, object_end); pair.input = to_dlib_matrix(raw_input); pair.input_rows = pair.input.nr(); pair.input_cols = pair.input.nc(); // Parse the "output" field (search starts after input) auto output_it = internal::find_key_value_start(content, "output", input_it); if (output_it == content.end() || output_it >= object_end) throw std::runtime_error("'output' not found in " + key + " object"); auto raw_output = internal::parse_arc_grid(output_it, object_end); pair.output = to_dlib_matrix(raw_output); pair.output_rows = pair.output.nr(); pair.output_cols = pair.output.nc(); pairs.push_back(pair); // Advance iterator past this object it = object_end; } }; parse_pairs("train", task.train_pairs); parse_pairs("test", task.test_pairs); return task; } // ------------------------------------------------------------------------------------ std::vector load_all_tasks(const std::string& directory_path, std::map& id_map) /*! 
ensures - Loads all .json files from the specified directory - Each file is parsed as an ARC task - Returns a vector of successfully loaded tasks - Populates id_map with task_id to index mappings - Outputs diagnostic information to stdout/stderr !*/ { std::vector tasks; std::cout << "Loading tasks from: " << directory_path << std::endl; try { const dlib::directory dir(directory_path); std::vector all_files = dir.get_files(); std::cout << "Found " << all_files.size() << " files in directory" << std::endl; // Filter for JSON files only std::vector json_files; for (const auto& file : all_files) { const std::string& filename = file.name(); if (filename.size() >= 5 && filename.substr(filename.size() - 5) == ".json") { json_files.push_back(file); } } std::cout << "Found " << json_files.size() << " .json files" << std::endl; if (json_files.empty()) { std::cout << "WARNING: No .json files found in " << directory_path << std::endl; return tasks; } size_t success_count = 0; size_t error_count = 0; // Attempt to load each JSON file for (const auto& file : json_files) { try { std::string content = internal::read_file_to_string(file.full_name()); arc_task task = parse_arc_task_from_content(content, file.name()); id_map[task.task_id] = tasks.size(); tasks.push_back(task); ++success_count; } catch (const std::exception& e) { std::cerr << "ERROR parsing " << file.name() << ": " << e.what() << std::endl; ++error_count; } } std::cout << "Successfully loaded " << success_count << " tasks" << std::endl; if (error_count > 0) { std::cout << "Failed to load " << error_count << " tasks" << std::endl; } } catch (const dlib::directory::dir_not_found& e) { std::cerr << "ERROR: Directory not found: " << directory_path << std::endl; std::cerr << "Details: " << e.info << std::endl; } catch (const dlib::directory::listing_error& e) { std::cerr << "ERROR: Cannot list directory: " << directory_path << std::endl; std::cerr << "Details: " << e.info << std::endl; } catch (const std::exception& e) { 
std::cerr << "ERROR during directory navigation: " << e.what() << std::endl; } return tasks; } public: arc_agi_manager() = default; // ------------------------------------------------------------------------------------ void load_data(const std::string& training_path, const std::string& evaluation_path) /*! ensures - Loads all ARC tasks from training and evaluation directories - Clears any previously loaded data - Outputs a summary of loaded tasks to stdout !*/ { training_task_id_map.clear(); evaluation_task_id_map.clear(); training_tasks = load_all_tasks(training_path, training_task_id_map); evaluation_tasks = load_all_tasks(evaluation_path, evaluation_task_id_map); std::cout << "--- ARC Data Loading Summary ---" << std::endl; std::cout << "Loaded " << training_tasks.size() << " training tasks" << std::endl; std::cout << "Loaded " << evaluation_tasks.size() << " evaluation tasks" << std::endl; std::cout << "--------------------------------" << std::endl; } // ------------------------------------------------------------------------------------ const arc_task& get_training_task(size_t index) const /*! requires - index < num_training_tasks() ensures - Returns a const reference to the training task at the given index throws - DLIB_CASSERT if index is out of bounds !*/ { DLIB_CASSERT(index < training_tasks.size(), "Training task index out of bounds" << "\n\tRequested index: " << index << "\n\tAvailable tasks: " << training_tasks.size()); return training_tasks[index]; } // ------------------------------------------------------------------------------------ const arc_task& get_evaluation_task(size_t index) const /*! 
requires - index < num_evaluation_tasks() ensures - Returns a const reference to the evaluation task at the given index throws - DLIB_CASSERT if index is out of bounds !*/ { DLIB_CASSERT(index < evaluation_tasks.size(), "Evaluation task index out of bounds"); return evaluation_tasks[index]; } // ------------------------------------------------------------------------------------ const arc_task& get_training_task_by_id(const std::string& task_id) const /*! ensures - Returns a const reference to the training task with the given ID throws - std::runtime_error if task_id is not found !*/ { auto it = training_task_id_map.find(task_id); if (it == training_task_id_map.end()) throw std::runtime_error("Training task ID not found: " + task_id); return training_tasks[it->second]; } // ------------------------------------------------------------------------------------ const arc_task& get_evaluation_task_by_id(const std::string& task_id) const /*! ensures - Returns a const reference to the evaluation task with the given ID throws - std::runtime_error if task_id is not found !*/ { auto it = evaluation_task_id_map.find(task_id); if (it == evaluation_task_id_map.end()) throw std::runtime_error("Evaluation task ID not found: " + task_id); return evaluation_tasks[it->second]; } // ------------------------------------------------------------------------------------ size_t num_training_tasks() const { return training_tasks.size(); } size_t num_evaluation_tasks() const { return evaluation_tasks.size(); } // ------------------------------------------------------------------------------------ void serialize(std::ostream& out) const /*! 
ensures - Serializes the entire dataset to the output stream - Format is versioned for forward compatibility !*/ { dlib::serialize("arc_agi_v1", out); dlib::serialize(training_tasks, out); dlib::serialize(evaluation_tasks, out); dlib::serialize(training_task_id_map, out); dlib::serialize(evaluation_task_id_map, out); } // ------------------------------------------------------------------------------------ void deserialize(std::istream& in) /*! ensures - Deserializes a dataset from the input stream - Replaces any existing data in this object throws - serialization_error if version mismatch is detected !*/ { std::string version; dlib::deserialize(version, in); if (version != "arc_agi_v1") throw serialization_error("Unexpected version in arc_agi_manager"); dlib::deserialize(training_tasks, in); dlib::deserialize(evaluation_tasks, in); dlib::deserialize(training_task_id_map, in); dlib::deserialize(evaluation_task_id_map, in); } // ---------------------------------------------------------------------------------------- // Tokenization for LLM-style training // ---------------------------------------------------------------------------------------- static arc_token_sequence_t tokenize_input_context(const arc_task& task, const arc_task_pair& test_pair) /*! 
ensures - Creates a token sequence representing the input context for a test pair - Format: [train_input, SEP_IO, train_output, SEP_PAIR]* QUERY_START test_input GEN_START - Each grid is tokenized with TOKEN_ROW_END markers preserving dimensions - Returns a column vector of tokens !*/ { std::vector sequence; // Encode all training demonstration pairs for (const auto& pair : task.train_pairs) { append_flat_grid(sequence, pair.input); sequence.push_back(TOKEN_SEP_IO); append_flat_grid(sequence, pair.output); sequence.push_back(TOKEN_SEP_PAIR); } // Encode the test query sequence.push_back(TOKEN_QUERY_START); append_flat_grid(sequence, test_pair.input); sequence.push_back(TOKEN_GEN_START); // Convert to dlib column vector arc_token_sequence_t result(static_cast(sequence.size())); for (long i = 0; i < static_cast(sequence.size()); ++i) result(i) = sequence[i]; return result; } // ------------------------------------------------------------------------------------ static arc_token_sequence_t tokenize_target_output(const arc_task_pair& test_pair) /*! ensures - Creates a token sequence for the target output grid - Format: output_grid END_OF_OUTPUT - Output grid includes TOKEN_ROW_END markers - Returns a column vector of tokens !*/ { std::vector sequence; append_flat_grid(sequence, test_pair.output); sequence.push_back(TOKEN_END_OF_OUTPUT); arc_token_sequence_t result(static_cast(sequence.size())); for (long i = 0; i < static_cast(sequence.size()); ++i) result(i) = sequence[i]; return result; } // ------------------------------------------------------------------------------------ static void prepare_training_data_batch( const arc_task& task, long window_len, std::vector& training_X_batch, std::vector& training_Y_batch) /*! 
requires - window_len > 1 ensures - Generates training samples using a sliding window approach - Each X sample contains window_len tokens of context - Each Y label is the next token following the context window - Padding tokens are used when the window extends beyond sequence boundaries - training_X_batch[i] is a column vector of length window_len - training_Y_batch[i] is the target token for training_X_batch[i] - Processes all test pairs in the task throws - DLIB_CASSERT if window_len <= 1 IMPLEMENTATION NOTES This function implements causal language modeling for ARC tasks. For each position in the concatenated [context + target] sequence, it creates a training example where: - X = [t_{pos-window_len+1}, ..., t_{pos}] - Y = t_{pos+1} The sliding window ensures the model learns to predict each token given the appropriate amount of left context. !*/ { DLIB_CASSERT(window_len > 1, "Window length must be greater than 1"); training_X_batch.clear(); training_Y_batch.clear(); for (const arc_task_pair& test_pair : task.test_pairs) { // Tokenize the full sequence: context + target arc_token_sequence_t input_context = tokenize_input_context(task, test_pair); arc_token_sequence_t target_output = tokenize_target_output(test_pair); long L_in = input_context.size(); long L_out = target_output.size(); long L_full = L_in + L_out; // Build the complete token sequence std::vector S_vec; S_vec.reserve(static_cast(L_full)); for (long i = 0; i < L_in; ++i) S_vec.push_back(input_context(i)); for (long i = 0; i < L_out; ++i) S_vec.push_back(target_output(i)); // Generate sliding window samples // For each position, create a context window of length window_len for (long pos = 0; pos < L_full; ++pos) { arc_token_sequence_t X_window(window_len); // Fill the context window // Window spans from (pos - window_len + 1) to pos inclusive for (long i = 0; i < window_len; ++i) { long context_idx = pos - window_len + 1 + i; // Use padding for positions before sequence start or after end if 
(context_idx < 0 || context_idx >= L_full) X_window(i) = TOKEN_PADDING; else X_window(i) = S_vec[static_cast(context_idx)]; } // The target is the next token after the window long y_token = (pos + 1 < L_full) ? S_vec[static_cast(pos + 1)] : TOKEN_PADDING; training_X_batch.push_back(std::move(X_window)); training_Y_batch.push_back(y_token); } } } // ---------------------------------------------------------------------------------------- // Detokenization utilities // ---------------------------------------------------------------------------------------- static arc_grid_t detokenize_to_grid(const arc_token_sequence_t& tokens, long start_idx = 0) /*! ensures - Reconstructs a grid from a tokenized sequence - Uses TOKEN_ROW_END markers to determine row boundaries - Stops at TOKEN_END_OF_OUTPUT, TOKEN_SEP_IO, or TOKEN_SEP_PAIR - Returns a matrix with the reconstructed grid - Returns an empty matrix if no valid grid is found throws - DLIB_CASSERT if row lengths are inconsistent IMPLEMENTATION NOTES This function recovers grid dimensions from the token stream by counting tokens between TOKEN_ROW_END markers. This allows the model to generate grids of arbitrary dimensions (1x1 to 30x30) without explicit dimension specification. !*/ { // Extract rows from the token sequence std::vector> rows; std::vector current_row; for (long i = start_idx; i < tokens.size(); ++i) { long token = tokens(i); if (token == TOKEN_ROW_END) { // End of current row - save it if non-empty if (!current_row.empty()) { rows.push_back(current_row); current_row.clear(); } } else if (token == TOKEN_END_OF_OUTPUT || token == TOKEN_SEP_IO || token == TOKEN_SEP_PAIR) { // End of grid section break; } else if (token >= COLOR_0 && token <= COLOR_9) { // Valid color token - add to current row current_row.push_back(static_cast(token)); } // Ignore other tokens (padding, etc.) 
} // Build the output matrix if (rows.empty()) return arc_grid_t(0, 0); long n_rows = static_cast(rows.size()); long n_cols = static_cast(rows[0].size()); arc_grid_t grid(n_rows, n_cols); for (long r = 0; r < n_rows; ++r) { DLIB_CASSERT(static_cast(rows[r].size()) == n_cols, "Inconsistent row length during detokenization" << "\n\tRow " << r << " has " << rows[r].size() << " columns" << "\n\tExpected " << n_cols << " columns"); for (long c = 0; c < n_cols; ++c) grid(r, c) = rows[r][c]; } return grid; } }; } // namespace dlib #endif // DLIB_ARC_AGI_H_ ================================================ FILE: dlib/data_io/arc_agi_abstract.h ================================================ // Copyright (C) 2025 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_ARC_AGI_ABSTRACT_H_ #ifdef DLIB_ARC_AGI_ABSTRACT_H_ #include #include #include #include "../matrix.h" #include "../serialize.h" namespace dlib { // Type aliases for ARC-AGI data structures using arc_grid_t = matrix; using arc_token_sequence_t = matrix; // Maximum sequence length for LLM-style training constexpr long ARC_MAX_SEQUENCE_LENGTH = 4096; // Token vocabulary for the Hierarchical Reasoning Model enum arc_token_id : long { COLOR_0 = 0, COLOR_1 = 1, COLOR_2 = 2, COLOR_3 = 3, COLOR_4 = 4, COLOR_5 = 5, COLOR_6 = 6, COLOR_7 = 7, COLOR_8 = 8, COLOR_9 = 9, TOKEN_SEP_IO = 10, TOKEN_SEP_PAIR = 11, TOKEN_QUERY_START = 12, TOKEN_GEN_START = 13, TOKEN_END_OF_OUTPUT = 14, TOKEN_PADDING = 15, TOKEN_ROW_END = 16 }; // Vocabulary size constants constexpr long ARC_VOCAB_SIZE_COLORS = 10; constexpr long ARC_VOCAB_SIZE_TOTAL = 17; struct arc_task_pair { /*! WHAT THIS OBJECT REPRESENTS Represents a single Input/Output example pair within an ARC task. Each pair demonstrates a transformation pattern that the model must learn. !*/ arc_grid_t input; /*! The input grid (2D matrix of color values 0-9) !*/ arc_grid_t output; /*! 
The corresponding output grid showing the transformed result !*/ long input_rows; long input_cols; long output_rows; long output_cols; /*! Dimensions of the input and output grids !*/ }; struct arc_task { /*! WHAT THIS OBJECT REPRESENTS Represents a complete ARC-AGI reasoning task containing: - Multiple training pairs demonstrating a pattern - One or more test pairs where the model must predict outputs !*/ std::string task_id; /*! Unique identifier extracted from the JSON filename !*/ std::vector train_pairs; /*! Training examples demonstrating the pattern to learn !*/ std::vector test_pairs; /*! Test cases where the model must predict the output !*/ }; class arc_agi_manager { /*! WHAT THIS OBJECT REPRESENTS This object provides utilities for loading, accessing, and preparing ARC-AGI (Abstraction and Reasoning Corpus for Artificial General Intelligence) dataset for training Transformer-based models such as the Hierarchical Reasoning Model (HRM). The ARC-AGI dataset consists of visual reasoning tasks where each task contains: - Training pairs: Input/Output grid examples demonstrating a pattern - Test pairs: Input grids where the model must predict the output Each grid is a 2D matrix of integers (0-9) representing colors/symbols, with maximum dimensions of 30x30. TOKENIZATION STRATEGY Grids are tokenized row-by-row with TOKEN_ROW_END markers inserted at the end of each row. This encoding preserves dimensional information implicitly, allowing the model to learn and generate grids of arbitrary dimensions (1x1 to 30x30, including non-square grids) without requiring explicit dimension specification. The dataset is available from: https://github.com/fchollet/ARC-AGI !*/ public: arc_agi_manager(); /*! ensures - Constructs an empty arc_agi_manager object !*/ void load_data( const std::string& training_path, const std::string& evaluation_path ); /*! 
ensures - Attempts to load the ARC-AGI dataset from the specified directories - training_path should contain JSON files for training tasks - evaluation_path should contain JSON files for evaluation tasks - Each JSON file represents one task with training and test pairs - Task IDs are extracted from filenames (without .json extension) throws - std::runtime_error if directories cannot be accessed or files cannot be parsed !*/ const arc_task& get_training_task(size_t index) const; /*! requires - index < num_training_tasks() ensures - Returns the training task at the specified index throws - std::out_of_range if index is out of bounds !*/ const arc_task& get_evaluation_task(size_t index) const; /*! requires - index < num_evaluation_tasks() ensures - Returns the evaluation task at the specified index throws - std::out_of_range if index is out of bounds !*/ const arc_task& get_training_task_by_id(const std::string& task_id) const; /*! requires - task_id is a valid task identifier ensures - Returns the training task with the specified task_id throws - std::runtime_error if task_id is not found !*/ const arc_task& get_evaluation_task_by_id(const std::string& task_id) const; /*! requires - task_id is a valid task identifier ensures - Returns the evaluation task with the specified task_id throws - std::runtime_error if task_id is not found !*/ size_t num_training_tasks() const; /*! ensures - Returns the number of loaded training tasks !*/ size_t num_evaluation_tasks() const; /*! ensures - Returns the number of loaded evaluation tasks !*/ void serialize(std::ostream& out) const; /*! ensures - Writes the entire dataset to the output stream in Dlib's serialization format - Can be saved to a .dat file for faster loading !*/ void deserialize(std::istream& in); /*! 
ensures - Loads the entire dataset from the input stream - Stream must contain data previously written by serialize() throws - serialization_error if data format is invalid !*/ static arc_token_sequence_t tokenize_input_context( const arc_task& task, const arc_task_pair& test_pair ); /*! ensures - Converts the task's training pairs and the specified test input into a token sequence suitable for LLM-style training - Returns a sequence: [grid_tokens..., ROW_END, SEP_IO, grid_tokens..., ROW_END, SEP_PAIR, ..., QUERY_START, test_input_tokens..., ROW_END, GEN_START] - Each grid is encoded with TOKEN_ROW_END markers at the end of each row to preserve dimensional information - This represents the context that the model uses to predict the output !*/ static arc_token_sequence_t tokenize_target_output( const arc_task_pair& test_pair ); /*! ensures - Converts the test output grid into a token sequence - Returns a sequence: [grid_tokens..., ROW_END, ..., END_OF_OUTPUT] - Each row is terminated with TOKEN_ROW_END to preserve dimensions - This represents the ground truth that the model should predict !*/ static void prepare_training_data_batch( const arc_task& task, long window_len, std::vector& training_X_batch, std::vector& training_Y_batch ); /*! 
requires - window_len > 1 ensures - Prepares training data in the format required by dlib::dnn::trainer using a sliding window approach for causal language modeling - For each test pair in the task, generates training samples where: * Each X sample is a context window of size window_len containing the previous window_len tokens * Each Y label is the next token that should follow the context - #training_X_batch.size() == #training_Y_batch.size() - Each training_X_batch[i] is a column vector (matrix) of size window_len x 1 - Each training_Y_batch[i] is a single token (long) representing the target to predict - Implements left-padding with TOKEN_PADDING when the context window extends before the sequence start, preserving recent context on the right side (standard for causal language models) - The concatenated sequence is: [input_context, target_output] throws - std::invalid_argument if window_len <= 1 EXAMPLE For a sequence [A, B, C, D, E] with window_len=3: X[0] = [PAD, PAD, A] => Y[0] = B X[1] = [PAD, A, B] => Y[1] = C X[2] = [A, B, C] => Y[2] = D X[3] = [B, C, D] => Y[3] = E X[4] = [C, D, E] => Y[4] = PAD !*/ static arc_grid_t detokenize_to_grid( const arc_token_sequence_t& tokens, long start_idx = 0 ); /*! 
requires - tokens contains a valid tokenized grid sequence with TOKEN_ROW_END markers ensures - Reconstructs a grid from a tokenized sequence - Uses TOKEN_ROW_END markers to determine row boundaries and infer grid dimensions - Parsing stops at TOKEN_END_OF_OUTPUT, TOKEN_SEP_IO, or TOKEN_SEP_PAIR - Returns a matrix containing the reconstructed grid - Returns an empty matrix (0x0) if no valid grid is found - Grid dimensions are automatically determined from the token stream: * Number of rows = count of TOKEN_ROW_END markers * Number of columns = tokens between consecutive TOKEN_ROW_END markers throws - DLIB_CASSERT if row lengths are inconsistent (indicating malformed data) EXAMPLE Input tokens: [1, 2, 3, ROW_END, 4, 5, 6, ROW_END, END_OF_OUTPUT] Returns: 2x3 grid = [[1, 2, 3], [4, 5, 6]] !*/ }; } // namespace dlib #endif // DLIB_ARC_AGI_ABSTRACT_H_ ================================================ FILE: dlib/data_io/cifar.cpp ================================================ // Copyright (C) 2020 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_CIFAR_CPp_ #define DLIB_CIFAR_CPp_ #include "cifar.h" #include // ---------------------------------------------------------------------------------------- namespace dlib { namespace impl { void load_cifar_10_batch ( const std::string& folder_name, const std::string& batch_name, const size_t first_idx, const size_t images_per_batch, std::vector>& images, std::vector& labels ) { std::ifstream fin(folder_name + "/" + batch_name, std::ios::binary); if (!fin) throw error("Unable to open file " + batch_name); const long nr = 32; const long nc = 32; const long plane_size = nr * nc; const long image_size = 3 * plane_size; for (size_t i = 0; i < images_per_batch; ++i) { char l; fin.read(&l, 1); labels[first_idx + i] = l; images[first_idx + i].set_size(nr, nc); std::array buffer; fin.read((char*)(buffer.data()), buffer.size()); for (long k = 0; k < plane_size; ++k) { char r = buffer[0 * plane_size + k]; char g = buffer[1 * plane_size + k]; char b = buffer[2 * plane_size + k]; const long row = k / nr; const long col = k % nr; images[first_idx + i](row, col) = rgb_pixel(r, g, b); } } if (!fin) throw error("Unable to read file " + batch_name); if (fin.get() != EOF) throw error("Unexpected bytes at end of " + batch_name); } } void load_cifar_10_dataset ( const std::string& folder_name, std::vector>& training_images, std::vector& training_labels, std::vector>& testing_images, std::vector& testing_labels ) { using namespace std; const size_t images_per_batch = 10000; const size_t num_training_batches = 5; const size_t num_testing_batches = 1; training_images.resize(images_per_batch * num_training_batches); training_labels.resize(images_per_batch * num_training_batches); testing_images.resize(images_per_batch * num_testing_batches); testing_labels.resize(images_per_batch * num_testing_batches); std::vector training_batches_names{ "data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin", }; for (size_t i = 0; i < 
num_training_batches; ++i) { impl::load_cifar_10_batch( folder_name, training_batches_names[i], i * images_per_batch, images_per_batch, training_images, training_labels); } impl::load_cifar_10_batch( folder_name, "test_batch.bin", 0, images_per_batch, testing_images, testing_labels); } } // ---------------------------------------------------------------------------------------- #endif // DLIB_CIFAR_CPp_ ================================================ FILE: dlib/data_io/cifar.h ================================================ // Copyright (C) 2020 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CIFAR_Hh_ #define DLIB_CIFAR_Hh_ #include "cifar_abstract.h" #include #include #include "../matrix.h" #include "../pixel.h" // ---------------------------------------------------------------------------------------- namespace dlib { void load_cifar_10_dataset ( const std::string& folder_name, std::vector>& training_images, std::vector& training_labels, std::vector>& testing_images, std::vector& testing_labels ); } // ---------------------------------------------------------------------------------------- #ifdef NO_MAKEFILE #include "cifar.cpp" #endif #endif // DLIB_CIFAR_Hh_ ================================================ FILE: dlib/data_io/cifar_abstract.h ================================================ // Copyright (C) 2020 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_CIFAR_ABSTRACT_Hh_ #ifdef DLIB_CIFAR_ABSTRACT_Hh_ #include #include #include "../matrix.h" #include "../pixel.h" // ---------------------------------------------------------------------------------------- namespace dlib { void load_cifar_10_dataset ( const std::string& folder_name, std::vector>& training_images, std::vector& training_labels, std::vector>& testing_images, std::vector& testing_labels ); /*! 
ensures - Attempts to load the CIFAR-10 dataset from the hard drive. The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. It is available from https://www.cs.toronto.edu/~kriz/cifar.html. In particular, the 6 files comprising the CIFAR-10 dataset should be present in the folder indicated by folder_name. These six files are: - data_batch_1.bin - data_batch_2.bin - data_batch_3.bin - data_batch_4.bin - data_batch_5.bin - test_batch.bin - #training_images == The 50,000 training images from the dataset. - #training_labels == The labels for the contents of #training_images. I.e. #training_labels[i] is the label of #training_images[i]. - #testing_images == The 10,000 testing images from the dataset. - #testing_labels == The labels for the contents of #testing_images. I.e. #testing_labels[i] is the label of #testing_images[i]. throws - dlib::error if some problem prevents us from loading the data or the files can't be found. !*/ } // ---------------------------------------------------------------------------------------- #endif // DLIB_CIFAR_ABSTRACT_Hh_ ================================================ FILE: dlib/data_io/image_dataset_metadata.cpp ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_IMAGE_DAtASET_METADATA_CPPh_
#define DLIB_IMAGE_DAtASET_METADATA_CPPh_

#include "image_dataset_metadata.h"

#include
#include
#include "../compress_stream.h"
#include "../base64.h"
#include "../xml_parser.h"
#include "../string.h"

// ----------------------------------------------------------------------------------------

namespace dlib
{
    namespace image_dataset_metadata
    {

    // ------------------------------------------------------------------------------------

        // Returns the stylesheet text that is written next to saved metadata files.
        // Defined at the end of this file.
        const std::string get_decoded_string();

        // Writes image_metadata_stylesheet.xsl into the directory part of main_filename,
        // i.e. everything up to and including its last '/' or backslash (nothing if there
        // is none).  Throws dlib::error if the file can't be opened or written.
        void create_image_metadata_stylesheet_file(const std::string& main_filename)
        {
            std::string path;
            std::string::size_type pos = main_filename.find_last_of("/\\");
            if (pos != std::string::npos)
                path = main_filename.substr(0,pos+1);

            std::ofstream fout((path + "image_metadata_stylesheet.xsl").c_str());
            if (!fout)
                throw dlib::error("ERROR: Unable to open image_metadata_stylesheet.xsl for writing.");

            fout << get_decoded_string();

            if (!fout)
                throw dlib::error("ERROR: Unable to write to image_metadata_stylesheet.xsl.");
        }

        // Writes meta to filename, after first writing the companion stylesheet via
        // create_image_metadata_stylesheet_file().  Throws dlib::error if filename can't
        // be opened or a write fails.
        //
        // NOTE(review): in this copy of the source the markup inside the string literals
        // below appears to be missing.  Several literals are empty or whitespace-only,
        // and the "else" further down has no matching "if", so this text does not
        // compile as shown.  Confirm against upstream dlib before relying on it.
        void save_image_dataset_metadata (
            const dataset& meta,
            const std::string& filename
        )
        {
            create_image_metadata_stylesheet_file(filename);

            const std::vector& images = meta.images;

            std::ofstream fout(filename.c_str());
            if (!fout)
                throw dlib::error("ERROR: Unable to open " + filename + " for writing.");

            fout << "\n";
            fout << "\n";
            fout << "\n";
            fout << "" << meta.name << "\n";
            fout << "" << meta.comment << "\n";
            fout << "\n";
            for (unsigned long i = 0; i < images.size(); ++i)
            {
                fout << " \n";

                // save all the boxes
                for (unsigned long j = 0; j < images[i].boxes.size(); ++j)
                {
                    const box& b = images[i].boxes[j];
                    fout << " \n";

                    if (b.has_label())
                        fout << " \n";

                    // save all the parts
                    std::map::const_iterator itr;
                    for (itr = b.parts.begin(); itr != b.parts.end(); ++itr)
                    {
                        fout << " \n";
                    }

                    fout << " \n";
                }
                else
                {
                    fout << "/>\n";
                }
            }
            fout << " \n";

            if (!fout)
                throw dlib::error("ERROR: Unable to write to " + filename + ".");
        }
        fout << "\n";
        fout << "";
    }

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------

        // Callback handler for xml_parser that rebuilds a dataset from the parse events.
        // ts is the stack of currently open element names (ts[0] is the root element).
        // temp_image and temp_box accumulate the element being parsed until its end tag.
        class doc_handler : public document_handler
        {
            std::vector ts;
            image temp_image;
            box temp_box;

            dataset& meta;

        public:

            // The dataset being filled is held by reference; it is reset at
            // start_document().
            doc_handler(
                dataset& metadata_
            ):
                meta(metadata_)
            {}


            // Clears the output dataset and all parse state.
            virtual void start_document (
            )
            {
                meta = dataset();
                ts.clear();
                temp_image = image();
                temp_box = box();
            }

            virtual void end_document (
            )
            {
            }

            // Handles an opening tag.  The first tag must be "dataset".  For "box",
            // top/left/width/height are required; width and height are first stored in
            // right()/bottom() and then turned into inclusive coordinates by adding
            // top-1 / left-1.  A "part" directly inside a box needs x, y and a name that
            // is unique within that box.  An "image" needs a file attribute.  Any
            // dlib::error raised here is rethrown with the line number prefixed.
            virtual void start_element (
                const unsigned long line_number,
                const std::string& name,
                const dlib::attribute_list& atts
            )
            {
                try
                {
                    if (ts.size() == 0)
                    {
                        if (name != "dataset")
                        {
                            std::ostringstream sout;
                            sout << "Invalid XML document.  Root tag must be .  Found <" << name << "> instead.";
                            throw dlib::error(sout.str());
                        }
                        else
                        {
                            ts.push_back(name);
                            return;
                        }
                    }


                    if (name == "box")
                    {
                        if (atts.is_in_list("top")) temp_box.rect.top() = sa = atts["top"];
                        else throw dlib::error(" missing required attribute 'top'");

                        if (atts.is_in_list("left")) temp_box.rect.left() = sa = atts["left"];
                        else throw dlib::error(" missing required attribute 'left'");

                        if (atts.is_in_list("width")) temp_box.rect.right() = sa = atts["width"];
                        else throw dlib::error(" missing required attribute 'width'");

                        if (atts.is_in_list("height")) temp_box.rect.bottom() = sa = atts["height"];
                        else throw dlib::error(" missing required attribute 'height'");

                        if (atts.is_in_list("difficult")) temp_box.difficult = sa = atts["difficult"];
                        if (atts.is_in_list("truncated")) temp_box.truncated = sa = atts["truncated"];
                        if (atts.is_in_list("occluded"))  temp_box.occluded = sa = atts["occluded"];
                        if (atts.is_in_list("ignore"))  temp_box.ignore = sa = atts["ignore"];
                        if (atts.is_in_list("angle"))  temp_box.angle = sa = atts["angle"];
                        if (atts.is_in_list("age"))  temp_box.age = sa = atts["age"];
                        if (atts.is_in_list("gender"))
                        {
                            if (atts["gender"] == "male")
                                temp_box.gender = MALE;
                            else if (atts["gender"] == "female")
                                temp_box.gender = FEMALE;
                            else if (atts["gender"] == "unknown")
                                temp_box.gender = UNKNOWN;
                            else
                                throw dlib::error("Invalid gender string in box attribute.");
                        }
                        if (atts.is_in_list("pose"))  temp_box.pose = sa = atts["pose"];
                        if (atts.is_in_list("detection_score"))  temp_box.detection_score = sa = atts["detection_score"];

                        // Convert the stored width/height into inclusive right/bottom edges.
                        temp_box.rect.bottom() += temp_box.rect.top()-1;
                        temp_box.rect.right() += temp_box.rect.left()-1;
                    }
                    else if (name == "part" && ts.back() == "box")
                    {
                        point temp;
                        if (atts.is_in_list("x")) temp.x() = sa = atts["x"];
                        else throw dlib::error(" missing required attribute 'x'");

                        if (atts.is_in_list("y")) temp.y() = sa = atts["y"];
                        else throw dlib::error(" missing required attribute 'y'");

                        if (atts.is_in_list("name"))
                        {
                            if (temp_box.parts.count(atts["name"])==0)
                            {
                                temp_box.parts[atts["name"]] = temp;
                            }
                            else
                            {
                                throw dlib::error(" with name '" + atts["name"] + "' is defined more than one time in a single box.");
                            }
                        }
                        else
                        {
                            throw dlib::error(" missing required attribute 'name'");
                        }
                    }
                    else if (name == "image")
                    {
                        temp_image.boxes.clear();

                        if (atts.is_in_list("file")) temp_image.filename = atts["file"];
                        else throw dlib::error(" missing required attribute 'file'");

                        if (atts.is_in_list("width")) temp_image.width = sa = atts["width"];
                        if (atts.is_in_list("height")) temp_image.height = sa = atts["height"];
                    }

                    ts.push_back(name);
                }
                catch (error& e)
                {
                    throw dlib::error("Error on line " + cast_to_string(line_number) + ": " + e.what());
                }
            }

            // Handles a closing tag.  A finished box is appended to the current image when
            // its parent is "image"; a finished image is appended to meta when its parent
            // is "images".
            virtual void end_element (
                const unsigned long ,
                const std::string& name
            )
            {
                ts.pop_back();
                if (ts.size() == 0)
                    return;

                if (name == "box" && ts.back() == "image")
                {
                    temp_image.boxes.push_back(temp_box);
                    temp_box = box();
                }
                else if (name == "image" && ts.back() == "images")
                {
                    meta.images.push_back(temp_image);
                    temp_image = image();
                }
            }

            // Text content: "name" and "comment" directly under the root set the dataset
            // fields, and "label" directly inside a box sets that box's label.  All of
            // them are trimmed.
            virtual void characters (
                const std::string& data
            )
            {
                if (ts.size() == 2 && ts[1] == "name")
                {
                    meta.name = trim(data);
                }
                else if (ts.size() == 2 && ts[1] == "comment")
                {
                    meta.comment = trim(data);
                }
                else if (ts.size() >= 2 && ts[ts.size()-1] == "label" &&
                         ts[ts.size()-2] == "box")
                {
                    temp_box.label = trim(data);
                }
            }

            virtual void processing_instruction (
                const unsigned long ,
                const std::string& ,
                const std::string&
            )
            {
            }
        };

// ----------------------------------------------------------------------------------------

        // Error handler for xml_parser: non-fatal errors are ignored, fatal errors throw
        // dlib::error naming the offending line.
        class xml_error_handler : public error_handler
        {
        public:
            virtual void error (
                const unsigned long
            ) { }

            virtual void fatal_error (
                const unsigned long line_number
            )
            {
                std::ostringstream sout;
                sout << "There is a fatal error on line " << line_number << " so parsing will now halt.";
                throw dlib::error(sout.str());
            }
        };

    // ------------------------------------------------------------------------------------

        // Parses filename with xml_parser, filling meta through doc_handler.  Throws
        // dlib::error if the file can't be opened or the handlers reject its contents.
        void load_image_dataset_metadata (
            dataset& meta,
            const std::string& filename
        )
        {
            xml_error_handler eh;
            doc_handler dh(meta);

            std::ifstream fin(filename.c_str());
            if (!fin)
                throw dlib::error("ERROR: unable to open " + filename + " for reading.");

            xml_parser parser;
            parser.add_document_handler(dh);
            parser.add_error_handler(eh);
            parser.parse(fin);
        }

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------

        // Returns the contents of the file 'image_metadata_stylesheet.xsl'.  The text is
        // embedded below as base64 of compressed data, and is decoded and then
        // decompressed on every call.
        const std::string get_decoded_string()
        {
            dlib::base64 base64_coder;
            dlib::compress_stream::kernel_1ea compressor;
            std::ostringstream sout;
            std::istringstream sin;

            // The base64 encoded data from the file 'image_metadata_stylesheet.xsl' we want to decode and return.
            sout << "PFWfgmWfCHr1DkV63lbjjeY2dCc2FbHDOVh0Kd7dkvaOfRYrOG24f0x77/5iMVq8FtE3UBxtGwSd";
            sout << "1ZHOHRSHgieNoeBv8ssJQ75RRxYtFKRY3OTPX5eKQoCN9jUaUnHnR4QZtEHgmKqXSs50Yrdd+2Ah";
            sout << "gNyarPZCiR6nvqNvCjtP2MP5FxleqNf8Fylatm2KdsXmrv5K87LYVN7i7JMkmZ++cTXYSOxDmxZi";
            sout << "OiCH8funXUdF9apDW547gCjz9HOQUI6dkz5dYUeFjfp6dFugpnaJyyprFLKq048Qk7+QiL4CNF/G";
            sout << "7e0VpBw8dMpiyRNi2fSQGSZGfIAUQKKT6+rPwQoRH2spdjsdXVWj4XQAqBX87nmqMnqjMhn/Vd1s";
            sout << "W5aoC0drwRGu3Xe3gn9vBL8hBkRXcJvEy6q/lb9bYnsLemhE5Zp/+nTmTBjfT9UFYLcsmgsjC+4n";
            sout << "Bq6h9QlpuyMYqJ8RvW8pp3mFlvXc3Yg+18t5F0hSMQfaIFYAuDPU2lVzPpY+ba0B39iu9IrPCLsS";
            sout << "+tUtSNSmQ74CtzZgKKjkTMA3nwYP2SDmZE3firq42pihT7hdU5vYkes69K8AQl8WZyLPpMww+r0z";
            sout << "+veEHPlAuxF7kL3ZvVjdB+xABwwqDe0kSRHRZINYdUfJwJdfYLyDnYoMjj6afqIJZ7QOBPZ42tV5";
            sout << "3hYOQTFwTNovOastzJJXQe1kxPg1AQ8ynmfjjJZqD0xKedlyeJybP919mVAA23UryHsq9TVlabou";
            sout << "qNl3xZW/mKKktvVsd/nuH62HIv/kgomyhaEUY5HgupupBUbQFZfyljZ5bl3g3V3Y1400Z1xTM/LL";
            sout << "LJpeLdlqoGzIe/19vAN1zUUVId9F/OLNUl3Zoar63yZERSJHcsuq/Pasisp0HIGi7rfI9EIQF7C/";
            sout << "IhLKLZsJ+LOycreQGOJALZIEZHOqxYLSXG0qaPM5bQL/MQJ2OZfwEhQgYOrjaM7oPOHHEfTq5kcO";
            sout << "daMwzefKfxrF2GXbUs0bYsEXsIGwENIUKMliFaAI4qKLxxb94oc+O3BRjWueZjZty2zKawQyTHNd";
            sout << "ltFJBUzfffdZN9Wq4zbPzntkM3U6Ys4LRztx5M15dtbhFeKx5rAf2tPXT6wU01hx7EJxBJzpvoDE";
            sout << "YwEoYVDSYulRKpgk82cHFzzUDgWXbl4paFSe1L1w8r9KHr67SYJDTUG86Lrm6LJ0rw73Xp0NAFcU";
            sout << "MKpiG9g1cHW74HYbUb/yAbtVWt40eB7M637umdo2jWz/r/vP5WnfSMXEbkyWebsa1fFceg/TLWy6";
            sout << "E8OTc4XKB48h1oFIlGagOiprxho3+F3TIcxDSwA=";



            // Put the data into the istream sin
            sin.str(sout.str());
            sout.str("");

            // Decode the base64 text into its compressed binary form
            base64_coder.decode(sin,sout);
            sin.clear();
            sin.str(sout.str());
            sout.str("");

            // Decompress the data into its original form
            compressor.decompress(sin,sout);

            // Return the decoded and decompressed data
            return sout.str();
        }

    }
}

// ----------------------------------------------------------------------------------------
#endif // DLIB_IMAGE_DAtASET_METADATA_CPPh_ ================================================ FILE: dlib/data_io/image_dataset_metadata.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_IMAGE_DAtASET_METADATA_Hh_ #define DLIB_IMAGE_DAtASET_METADATA_Hh_ #include #include #include "../geometry.h" // ---------------------------------------------------------------------------------------- namespace dlib { namespace image_dataset_metadata { // ------------------------------------------------------------------------------------ enum gender_t { UNKNOWN, MALE, FEMALE }; // ------------------------------------------------------------------------------------ struct box { /*! WHAT THIS OBJECT REPRESENTS This object represents an annotated rectangular area of an image. It is typically used to mark the location of an object such as a person, car, etc. The main variable of interest is rect. It gives the location of the box. All the other variables are optional. !*/ box( ) : difficult(false), truncated(false), occluded(false), ignore(false), pose(0), detection_score(0), angle(0), gender(UNKNOWN), age(0) {} box ( const rectangle& rect_ ) : rect(rect_), difficult(false), truncated(false), occluded(false), ignore(false), pose(0), detection_score(0), angle(0), gender(UNKNOWN), age(0) {} rectangle rect; std::map parts; // optional fields std::string label; bool difficult; bool truncated; bool occluded; bool ignore; double pose; double detection_score; // The angle of the object in radians. Positive values indicate that the // object at the center of the box is rotated clockwise by angle radians. A // value of 0 would indicate that the object is in its "standard" upright pose. // Therefore, to make the object appear upright we would have to rotate the // image counter-clockwise by angle radians. 
double angle; gender_t gender; double age; bool has_label() const { return label.size() != 0; } /*! ensures - returns true if label metadata is present and false otherwise. !*/ }; // ------------------------------------------------------------------------------------ struct image { /*! WHAT THIS OBJECT REPRESENTS This object represents an annotated image. !*/ image() {} image(const std::string& f) : filename(f) {} std::string filename; std::vector boxes; long width = 0; long height = 0; }; // ------------------------------------------------------------------------------------ struct dataset { /*! WHAT THIS OBJECT REPRESENTS This object represents a labeled set of images. In particular, it contains the filename for each image as well as annotated boxes. !*/ std::vector images; std::string comment; std::string name; }; // ------------------------------------------------------------------------------------ void save_image_dataset_metadata ( const dataset& meta, const std::string& filename ); /*! ensures - Writes the contents of the meta object to a file with the given filename. The file will be in an XML format. throws - dlib::error This exception is thrown if there is an error which prevents this function from succeeding. !*/ // ------------------------------------------------------------------------------------ void load_image_dataset_metadata ( dataset& meta, const std::string& filename ); /*! ensures - Attempts to interpret filename as a file containing XML formatted data as produced by the save_image_dataset_metadata() function. Then meta is loaded with the contents of the file. throws - dlib::error This exception is thrown if there is an error which prevents this function from succeeding. 
!*/ // ------------------------------------------------------------------------------------ } } // ---------------------------------------------------------------------------------------- #ifdef NO_MAKEFILE #include "image_dataset_metadata.cpp" #endif #endif // DLIB_IMAGE_DAtASET_METADATA_Hh_ ================================================ FILE: dlib/data_io/libsvm_io.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_LIBSVM_iO_Hh_ #define DLIB_LIBSVM_iO_Hh_ #include "libsvm_io_abstract.h" #include #include #include #include "../algs.h" #include "../matrix.h" #include "../string.h" #include "../svm/sparse_vector.h" #include namespace dlib { struct sample_data_io_error : public error { sample_data_io_error(const std::string& message): error(message) {} }; // ---------------------------------------------------------------------------------------- template void load_libsvm_formatted_data ( const std::string& file_name, std::vector& samples, std::vector& labels ) { typedef typename sample_type::value_type pair_type; typedef typename basic_type::type key_type; typedef typename pair_type::second_type value_type; // You must use unsigned integral key types in your sparse vectors COMPILE_TIME_ASSERT(is_unsigned_type::value); samples.clear(); labels.clear(); std::ifstream fin(file_name.c_str()); if (!fin) throw sample_data_io_error("Unable to open file " + file_name); std::string line; std::istringstream sin; key_type key; value_type value; label_type label; sample_type sample; long line_num = 0; while (fin.peek() != EOF) { ++line_num; std::getline(fin, line); std::string::size_type pos = line.find_first_not_of(" \t\r\n"); // ignore empty lines or comment lines if (pos == std::string::npos || line[pos] == '#') continue; sin.clear(); sin.str(line); sample.clear(); sin >> label; if (!sin) throw sample_data_io_error("On line: " + 
cast_to_string(line_num) + ", error while reading file " + file_name ); // eat whitespace sin >> std::ws; while (sin.peek() != EOF && sin.peek() != '#') { sin >> key >> std::ws; // ignore what should be a : character if (sin.get() != ':') throw sample_data_io_error("On line: " + cast_to_string(line_num) + ", error while reading file " + file_name); sin >> value; if (sin && value != 0) { sample.insert(sample.end(), std::make_pair(key, value)); } sin >> std::ws; } samples.push_back(sample); labels.push_back(label); } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template typename enable_if >::type fix_nonzero_indexing ( std::vector& samples ) { typedef typename sample_type::value_type pair_type; typedef typename basic_type::type key_type; if (samples.size() == 0) return; // figure out the min index value key_type min_idx = samples[0].begin()->first; for (unsigned long i = 0; i < samples.size(); ++i) min_idx = std::min(min_idx, samples[i].begin()->first); // Now adjust all the samples so that their min index value is zero. if (min_idx != 0) { sample_type temp; for (unsigned long i = 0; i < samples.size(); ++i) { // copy samples[i] into temp but make sure it has a min index of zero. temp.clear(); typename sample_type::iterator j; for (j = samples[i].begin(); j != samples[i].end(); ++j) { temp.insert(temp.end(), std::make_pair(j->first-min_idx, j->second)); } // replace the current sample with temp. samples[i].swap(temp); } } } // ---------------------------------------------------------------------------------------- // If the "first" values in the std::pair objects are not const then we can modify them // directly and that is what this version of fix_nonzero_indexing() does. 
template typename disable_if >::type fix_nonzero_indexing ( std::vector& samples ) { typedef typename sample_type::value_type pair_type; typedef typename basic_type::type key_type; if (samples.size() == 0) return; // figure out the min index value key_type min_idx = samples[0].begin()->first; for (unsigned long i = 0; i < samples.size(); ++i) min_idx = std::min(min_idx, samples[i].begin()->first); // Now adjust all the samples so that their min index value is zero. if (min_idx != 0) { for (unsigned long i = 0; i < samples.size(); ++i) { typename sample_type::iterator j; for (j = samples[i].begin(); j != samples[i].end(); ++j) { j->first -= min_idx; } } } } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // This is an overload for sparse vectors template typename disable_if,void>::type save_libsvm_formatted_data ( const std::string& file_name, const std::vector& samples, const std::vector& labels ) { typedef typename sample_type::value_type pair_type; typedef typename basic_type::type key_type; // You must use unsigned integral key types in your sparse vectors COMPILE_TIME_ASSERT(is_unsigned_type::value); // make sure requires clause is not broken DLIB_ASSERT(samples.size() == labels.size(), "\t void save_libsvm_formatted_data()" << "\n\t You have to have labels for each sample and vice versa" << "\n\t samples.size(): " << samples.size() << "\n\t labels.size(): " << labels.size() ); std::ofstream fout(file_name.c_str()); fout.precision(14); if (!fout) throw sample_data_io_error("Unable to open file " + file_name); for (unsigned long i = 0; i < samples.size(); ++i) { fout << labels[i]; for (typename sample_type::const_iterator j = samples[i].begin(); j != samples[i].end(); ++j) { if (j->second != 0) fout << " " << j->first << ":" << j->second; } fout << "\n"; if (!fout) throw sample_data_io_error("Error while writing to file " + 
file_name); } } // ---------------------------------------------------------------------------------------- // This is an overload for dense vectors template typename enable_if,void>::type save_libsvm_formatted_data ( const std::string& file_name, const std::vector& samples, const std::vector& labels ) { // make sure requires clause is not broken DLIB_ASSERT(samples.size() == labels.size(), "\t void save_libsvm_formatted_data()" << "\n\t You have to have labels for each sample and vice versa" << "\n\t samples.size(): " << samples.size() << "\n\t labels.size(): " << labels.size() ); std::ofstream fout(file_name.c_str()); fout.precision(14); if (!fout) throw sample_data_io_error("Unable to open file " + file_name); for (unsigned long i = 0; i < samples.size(); ++i) { fout << labels[i]; for (long j = 0; j < samples[i].size(); ++j) { if (samples[i](j) != 0) fout << " " << j << ":" << samples[i](j); } fout << "\n"; if (!fout) throw sample_data_io_error("Error while writing to file " + file_name); } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_LIBSVM_iO_Hh_ ================================================ FILE: dlib/data_io/libsvm_io_abstract.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_LIBSVM_iO_ABSTRACT_Hh_ #ifdef DLIB_LIBSVM_iO_ABSTRACT_Hh_ #include #include #include #include "../algs.h" #include "../matrix.h" #include namespace dlib { struct sample_data_io_error : public error { /*! This is the exception class used by the file IO functions defined below. !*/ }; // ---------------------------------------------------------------------------------------- template < typename sample_type, typename label_type, typename alloc1, typename alloc2 > void load_libsvm_formatted_data ( const std::string& file_name, std::vector& samples, std::vector& labels ); /*! 
requires - sample_type must be an STL container - sample_type::value_type == std::pair where T is some kind of unsigned integral type ensures - attempts to read a file of the given name that should contain libsvm formatted data. We turn the data into sparse vectors and store it in samples - #labels.size() == #samples.size() - for all valid i: #labels[i] is the label for #samples[i] throws - sample_data_io_error This exception is thrown if there is any problem loading data from file !*/ // ---------------------------------------------------------------------------------------- template < typename sample_type, typename label_type, typename alloc1, typename alloc2 > void save_libsvm_formatted_data ( const std::string& file_name, const std::vector& samples, const std::vector& labels ); /*! requires - sample_type must be an STL container - sample_type::value_type == std::pair where T is some kind of unsigned integral type - samples.size() == labels.size() ensures - saves the data to the given file in libsvm format throws - sample_data_io_error This exception is thrown if there is any problem saving data to file !*/ // ---------------------------------------------------------------------------------------- template < typename sample_type, typename label_type, typename alloc1, typename alloc2 > void save_libsvm_formatted_data ( const std::string& file_name, const std::vector& samples, const std::vector& labels ); /*! requires - sample_type == a dense matrix (i.e. dlib::matrix) - for all valid i: is_vector(samples[i]) == true - samples.size() == labels.size() ensures - saves the data to the given file in libsvm format throws - sample_data_io_error This exception is thrown if there is any problem saving data to file !*/ // ---------------------------------------------------------------------------------------- template void fix_nonzero_indexing ( std::vector& samples ); /*! requires - samples must only contain valid sparse vectors. 
The definition of a sparse vector can be found at the top of dlib/svm/sparse_vector_abstract.h ensures - Adjusts the sparse vectors in samples so that they are zero-indexed. Or in other words, assume the smallest used index value in any of the sparse vectors is N. Then this function subtracts N from all the index values in samples. This is useful, for example, if you load a libsvm formatted datafile with features indexed from 1 rather than 0 and you would like to fix this. !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_LIBSVM_iO_ABSTRACT_Hh_ ================================================ FILE: dlib/data_io/load_image_dataset.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_LOAD_IMAGE_DaTASET_Hh_ #define DLIB_LOAD_IMAGE_DaTASET_Hh_ #include "load_image_dataset_abstract.h" #include "../misc_api.h" #include "../dir_nav.h" #include "../image_io.h" #include "../array.h" #include #include "../geometry.h" #include "image_dataset_metadata.h" #include #include #include "../image_processing/full_object_detection.h" #include #include #include "../image_transforms/image_pyramid.h" namespace dlib { // ---------------------------------------------------------------------------------------- class image_dataset_file { public: image_dataset_file(const std::string& filename) { _skip_empty_images = false; _have_parts = false; _filename = filename; _box_area_thresh = std::numeric_limits::infinity(); } image_dataset_file boxes_match_label( const std::string& label ) const { image_dataset_file temp(*this); temp._labels.insert(label); return temp; } image_dataset_file skip_empty_images( ) const { image_dataset_file temp(*this); temp._skip_empty_images = true; return temp; } image_dataset_file boxes_have_parts( ) const { image_dataset_file temp(*this); temp._have_parts = true; return 
temp; } image_dataset_file shrink_big_images( double new_box_area_thresh = 150*150 ) const { image_dataset_file temp(*this); temp._box_area_thresh = new_box_area_thresh; return temp; } bool should_load_box ( const image_dataset_metadata::box& box ) const { if (_have_parts && box.parts.size() == 0) return false; if (_labels.size() == 0) return true; if (_labels.count(box.label) != 0) return true; return false; } const std::string& get_filename() const { return _filename; } bool should_skip_empty_images() const { return _skip_empty_images; } bool should_boxes_have_parts() const { return _have_parts; } double box_area_thresh() const { return _box_area_thresh; } const std::set& get_selected_box_labels() const { return _labels; } private: std::string _filename; std::set _labels; bool _skip_empty_images; bool _have_parts; double _box_area_thresh; }; // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source ) { images.clear(); object_locations.clear(); std::vector > ignored_rects; using namespace dlib::image_dataset_metadata; dataset data; load_image_dataset_metadata(data, source.get_filename()); // Set the current directory to be the one that contains the // metadata file. We do this because the file might contain // file paths which are relative to this folder. 
locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); typedef typename array_type::value_type image_type; image_type img; std::vector rects, ignored; for (unsigned long i = 0; i < data.images.size(); ++i) { double min_rect_size = std::numeric_limits::infinity(); rects.clear(); ignored.clear(); for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) { if (source.should_load_box(data.images[i].boxes[j])) { if (data.images[i].boxes[j].ignore) { ignored.push_back(data.images[i].boxes[j].rect); } else { rects.push_back(data.images[i].boxes[j].rect); min_rect_size = std::min(min_rect_size, rects.back().area()); } } } if (!source.should_skip_empty_images() || rects.size() != 0) { load_image(img, data.images[i].filename); if (rects.size() != 0) { // if shrinking the image would still result in the smallest box being // bigger than the box area threshold then shrink the image. while(min_rect_size/2/2 > source.box_area_thresh()) { pyramid_down<2> pyr; pyr(img); min_rect_size *= (1.0/2.0)*(1.0/2.0); for (auto&& r : rects) r = pyr.rect_down(r); for (auto&& r : ignored) r = pyr.rect_down(r); } while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) { pyramid_down<3> pyr; pyr(img); min_rect_size *= (2.0/3.0)*(2.0/3.0); for (auto&& r : rects) r = pyr.rect_down(r); for (auto&& r : ignored) r = pyr.rect_down(r); } } images.push_back(img); object_locations.push_back(rects); ignored_rects.push_back(ignored); } } return ignored_rects; } // ---------------------------------------------------------------------------------------- namespace impl { inline size_t num_non_ignored_boxes (const std::vector& rects) { size_t cnt = 0; for (auto& b : rects) { if (!b.ignore) cnt++; } return cnt; } } template < typename array_type > void load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source ) { images.clear(); object_locations.clear(); using namespace dlib::image_dataset_metadata; dataset data; 
load_image_dataset_metadata(data, source.get_filename()); // Set the current directory to be the one that contains the // metadata file. We do this because the file might contain // file paths which are relative to this folder. locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); typedef typename array_type::value_type image_type; image_type img; std::vector rects; for (unsigned long i = 0; i < data.images.size(); ++i) { double min_rect_size = std::numeric_limits::infinity(); rects.clear(); for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) { if (source.should_load_box(data.images[i].boxes[j])) { if (data.images[i].boxes[j].ignore) { rects.push_back(ignored_mmod_rect(data.images[i].boxes[j].rect)); } else { rects.push_back(mmod_rect(data.images[i].boxes[j].rect)); min_rect_size = std::min(min_rect_size, rects.back().rect.area()); } rects.back().label = data.images[i].boxes[j].label; } } if (!source.should_skip_empty_images() || impl::num_non_ignored_boxes(rects) != 0) { load_image(img, data.images[i].filename); if (rects.size() != 0) { // if shrinking the image would still result in the smallest box being // bigger than the box area threshold then shrink the image. 
while(min_rect_size/2/2 > source.box_area_thresh()) { pyramid_down<2> pyr; pyr(img); min_rect_size *= (1.0/2.0)*(1.0/2.0); for (auto&& r : rects) r.rect = pyr.rect_down(r.rect); } while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) { pyramid_down<3> pyr; pyr(img); min_rect_size *= (2.0/3.0)*(2.0/3.0); for (auto&& r : rects) r.rect = pyr.rect_down(r.rect); } } images.push_back(std::move(img)); object_locations.push_back(std::move(rects)); } } } // ---------------------------------------------------------------------------------------- // ******* THIS FUNCTION IS DEPRECATED, you should use another version of load_image_dataset() ******* template < typename image_type, typename MM > std::vector > load_image_dataset ( array& images, std::vector >& object_locations, const std::string& filename, const std::string& label, bool skip_empty_images = false ) { image_dataset_file f(filename); if (label.size() != 0) f = f.boxes_match_label(label); if (skip_empty_images) f = f.skip_empty_images(); return load_image_dataset(images, object_locations, f); } // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const std::string& filename ) { return load_image_dataset(images, object_locations, image_dataset_file(filename)); } // ---------------------------------------------------------------------------------------- template < typename array_type > void load_image_dataset ( array_type& images, std::vector>& object_locations, const std::string& filename ) { load_image_dataset(images, object_locations, image_dataset_file(filename)); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- 
template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source, std::vector& parts_list ) { typedef typename array_type::value_type image_type; parts_list.clear(); images.clear(); object_locations.clear(); using namespace dlib::image_dataset_metadata; dataset data; load_image_dataset_metadata(data, source.get_filename()); // Set the current directory to be the one that contains the // metadata file. We do this because the file might contain // file paths which are relative to this folder. locally_change_current_dir chdir(get_parent_directory(file(source.get_filename()))); std::set all_parts; // find out what parts are being used in the dataset. Store results in all_parts. for (unsigned long i = 0; i < data.images.size(); ++i) { for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) { if (source.should_load_box(data.images[i].boxes[j])) { const std::map& parts = data.images[i].boxes[j].parts; std::map::const_iterator itr; for (itr = parts.begin(); itr != parts.end(); ++itr) { all_parts.insert(itr->first); } } } } // make a mapping between part names and the integers [0, all_parts.size()) std::map parts_idx; for (std::set::iterator i = all_parts.begin(); i != all_parts.end(); ++i) { parts_idx[*i] = parts_list.size(); parts_list.push_back(*i); } std::vector > ignored_rects; std::vector ignored; image_type img; std::vector object_dets; for (unsigned long i = 0; i < data.images.size(); ++i) { double min_rect_size = std::numeric_limits::infinity(); object_dets.clear(); ignored.clear(); for (unsigned long j = 0; j < data.images[i].boxes.size(); ++j) { if (source.should_load_box(data.images[i].boxes[j])) { if (data.images[i].boxes[j].ignore) { ignored.push_back(data.images[i].boxes[j].rect); } else { std::vector partlist(parts_idx.size(), OBJECT_PART_NOT_PRESENT); // populate partlist with all the parts present in this box. 
const std::map& parts = data.images[i].boxes[j].parts; std::map::const_iterator itr; for (itr = parts.begin(); itr != parts.end(); ++itr) { partlist[parts_idx[itr->first]] = itr->second; } object_dets.push_back(full_object_detection(data.images[i].boxes[j].rect, partlist)); min_rect_size = std::min(min_rect_size, object_dets.back().get_rect().area()); } } } if (!source.should_skip_empty_images() || object_dets.size() != 0) { load_image(img, data.images[i].filename); if (object_dets.size() != 0) { // if shrinking the image would still result in the smallest box being // bigger than the box area threshold then shrink the image. while(min_rect_size/2/2 > source.box_area_thresh()) { pyramid_down<2> pyr; pyr(img); min_rect_size *= (1.0/2.0)*(1.0/2.0); for (auto&& r : object_dets) { r.get_rect() = pyr.rect_down(r.get_rect()); for (unsigned long k = 0; k < r.num_parts(); ++k) r.part(k) = pyr.point_down(r.part(k)); } for (auto&& r : ignored) { r = pyr.rect_down(r); } } while(min_rect_size*(2.0/3.0)*(2.0/3.0) > source.box_area_thresh()) { pyramid_down<3> pyr; pyr(img); min_rect_size *= (2.0/3.0)*(2.0/3.0); for (auto&& r : object_dets) { r.get_rect() = pyr.rect_down(r.get_rect()); for (unsigned long k = 0; k < r.num_parts(); ++k) r.part(k) = pyr.point_down(r.part(k)); } for (auto&& r : ignored) { r = pyr.rect_down(r); } } } images.push_back(img); object_locations.push_back(object_dets); ignored_rects.push_back(ignored); } } return ignored_rects; } // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source ) { std::vector parts_list; return load_image_dataset(images, object_locations, source, parts_list); } // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, 
std::vector >& object_locations, const std::string& filename ) { std::vector parts_list; return load_image_dataset(images, object_locations, image_dataset_file(filename), parts_list); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_LOAD_IMAGE_DaTASET_Hh_ ================================================ FILE: dlib/data_io/load_image_dataset_abstract.h ================================================ // Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ #ifdef DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ #include "image_dataset_metadata.h" #include "../array/array_kernel_abstract.h" #include #include #include "../image_processing/full_object_detection_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- class image_dataset_file { /*! WHAT THIS OBJECT REPRESENTS This object is a tool used to tell the load_image_dataset() functions which boxes and images to load from an XML based image dataset file. By default, this object tells load_image_dataset() to load all images and object boxes. !*/ public: image_dataset_file( const std::string& filename ); /*! ensures - #get_filename() == filename - #should_skip_empty_images() == false - #get_selected_box_labels().size() == 0 This means that, initially, all boxes will be loaded. Therefore, for all possible boxes B we have: - #should_load_box(B) == true - #box_area_thresh() == infinity !*/ const std::string& get_filename( ) const; /*! ensures - returns the name of the XML image dataset metadata file given to this object's constructor. !*/ bool should_skip_empty_images( ) const; /*! ensures - returns true if we are supposed to skip images that don't have any non-ignored boxes to load when loading an image dataset using load_image_dataset(). 
!*/ image_dataset_file boxes_match_label( const std::string& label ) const; /*! ensures - returns a copy of *this that is identical in all respects to *this except that label will be included in the labels set (i.e. the set returned by get_selected_box_labels()). !*/ const std::set& get_selected_box_labels( ) const; /*! ensures - returns the set of box labels currently selected by the should_load_box() method. Note that if the set is empty then we select all boxes. !*/ image_dataset_file skip_empty_images( ) const; /*! ensures - returns a copy of *this that is identical in all respects to *this except that #should_skip_empty_images() == true. !*/ bool should_boxes_have_parts( ) const; /*! ensures - returns true if boxes must have some parts defined for them to be loaded. !*/ image_dataset_file boxes_have_parts( ) const; /*! ensures - returns a copy of *this that is identical in all respects to *this except that #should_boxes_have_parts() == true. !*/ bool should_load_box ( const image_dataset_metadata::box& box ) const; /*! ensures - returns true if we are supposed to load the given box from an image dataset XML file. In particular, if should_load_box() returns false then the load_image_dataset() routines will not return the box at all, neither in the ignore rectangles list or in the primary object_locations vector. The behavior of this function is defined as follows: - if (should_boxes_have_parts() && boxes.parts.size() == 0) then - returns false - else if (get_selected_box_labels().size() == 0) then - returns true - else if (get_selected_box_labels().count(box.label) != 0) then - returns true - else - returns false !*/ image_dataset_file shrink_big_images( double new_box_area_thresh = 150*150 ) const; /*! ensures - returns a copy of *this that is identical in all respects to *this except that #box_area_thresh() == new_box_area_thresh !*/ double box_area_thresh( ) const; /*! 
ensures - If the smallest non-ignored rectangle in an image has an area greater than box_area_thresh() then we will shrink the image until the area of the box is about equal to box_area_thresh(). This is useful if you have a dataset containing very high resolution images and you don't want to load it in its native high resolution. Setting the box_area_thresh() allows you to control the resolution of the loaded images. !*/ }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. ensures - This routine loads the images and their associated object boxes from the image metadata file indicated by source.get_filename(). This metadata file should be in the XML format used by the save_image_dataset_metadata() routine. - #images.size() == The number of images loaded from the metadata file. This is all the images listed in the file unless source.should_skip_empty_images() is set to true. - #images.size() == #object_locations.size() - This routine is capable of loading any image format which can be read by the load_image() routine. - let IGNORED_RECTS denote the vector returned from this function. - IGNORED_RECTS.size() == #object_locations.size() - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to true in the input XML file. - for all valid i: - #images[i] == a copy of the i-th image from the dataset. - #object_locations[i] == a vector of all the rectangles associated with #images[i]. 
These are the rectangles for which source.should_load_box() returns true and are also not marked as "ignore" in the XML file. - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i] that are marked as "ignore" but not discarded by source.should_load_box(). - if (source.should_skip_empty_images() == true) then - #object_locations[i].size() != 0 (i.e. we won't load images that don't end up having any object locations) !*/ // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const std::string& filename ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. ensures - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename)); (i.e. it ignores box labels and therefore loads all the boxes in the dataset) !*/ // ---------------------------------------------------------------------------------------- template < typename array_type > void load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. ensures - This function has essentially the same behavior as the above load_image_dataset() routines, except here we output to a vector of mmod_rects instead of rectangles. In this case, both ignore and non-ignore rectangles go into object_locations since mmod_rect has an ignore boolean field that records the ignored/non-ignored state of each rectangle. 
We also store a each box's string label into the mmod_rect::label field as well. !*/ // ---------------------------------------------------------------------------------------- template < typename array_type > void load_image_dataset ( array_type& images, std::vector >& object_locations, const std::string& filename ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. ensures - performs: load_image_dataset(images, object_locations, image_dataset_file(filename)); (i.e. it ignores box labels and therefore loads all the boxes in the dataset) !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source, std::vector& parts_list ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. ensures - This routine loads the images and their associated object locations from the image metadata file indicated by source.get_filename(). This metadata file should be in the XML format used by the save_image_dataset_metadata() routine. - The difference between this function and the version of load_image_dataset() defined above is that this version will also load object part information and thus fully populates the full_object_detection objects. - #images.size() == The number of images loaded from the metadata file. This is all the images listed in the file unless source.should_skip_empty_images() is set to true. 
- #images.size() == #object_locations.size() - This routine is capable of loading any image format which can be read by the load_image() routine. - #parts_list == a vector that contains the list of object parts found in the input file and loaded into object_locations. - #parts_list is in lexicographic sorted order. - let IGNORED_RECTS denote the vector returned from this function. - IGNORED_RECTS.size() == #object_locations.size() - IGNORED_RECTS == a list of the rectangles which have the "ignore" flag set to true in the input XML file. - for all valid i: - #images[i] == a copy of the i-th image from the dataset. - #object_locations[i] == a vector of all the rectangles associated with #images[i]. These are the rectangles for which source.should_load_box() returns true and are also not marked as "ignore" in the XML file. - IGNORED_RECTS[i] == A vector of all the rectangles associated with #images[i] that are marked as "ignore" but not discarded by source.should_load_box(). - if (source.should_skip_empty_images() == true) then - #object_locations[i].size() != 0 (i.e. we won't load images that don't end up having any object locations) - for all valid j: - #object_locations[i][j].num_parts() == #parts_list.size() - for all valid k: - #object_locations[i][j].part(k) == the location of the part with name #parts_list[k] or OBJECT_PART_NOT_PRESENT if the part was not indicated for object #object_locations[i][j]. !*/ // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const image_dataset_file& source ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. 
ensures - performs: return load_image_dataset(images, object_locations, source, parts_list); (i.e. this function simply calls the above function and discards the output parts_list. So it is just a convenience function you can call if you don't care about getting the parts list.) !*/ // ---------------------------------------------------------------------------------------- template < typename array_type > std::vector > load_image_dataset ( array_type& images, std::vector >& object_locations, const std::string& filename ); /*! requires - array_type == An array of images. This is anything with an interface that looks like std::vector where a "generic image" is anything that implements the generic image interface defined in dlib/image_processing/generic_image.h. ensures - performs: return load_image_dataset(images, object_locations, image_dataset_file(filename)); (i.e. it ignores box labels and therefore loads all the boxes in the dataset) !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_LOAD_IMAGE_DaTASET_ABSTRACT_Hh_ ================================================ FILE: dlib/data_io/mnist.cpp ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_MNIST_CPp_ #define DLIB_MNIST_CPp_ #include "mnist.h" #include #include "../byte_orderer.h" #include "../uintn.h" // ---------------------------------------------------------------------------------------- namespace dlib { void load_mnist_dataset ( const std::string& folder_name, std::vector >& training_images, std::vector& training_labels, std::vector >& testing_images, std::vector& testing_labels ) { using namespace std; ifstream fin1((folder_name+"/train-images-idx3-ubyte").c_str(), ios::binary); if (!fin1) { fin1.open((folder_name + "/train-images.idx3-ubyte").c_str(), ios::binary); } ifstream fin2((folder_name+"/train-labels-idx1-ubyte").c_str(), ios::binary); if (!fin2) { fin2.open((folder_name + "/train-labels.idx1-ubyte").c_str(), ios::binary); } ifstream fin3((folder_name+"/t10k-images-idx3-ubyte").c_str(), ios::binary); if (!fin3) { fin3.open((folder_name + "/t10k-images.idx3-ubyte").c_str(), ios::binary); } ifstream fin4((folder_name+"/t10k-labels-idx1-ubyte").c_str(), ios::binary); if (!fin4) { fin4.open((folder_name + "/t10k-labels.idx1-ubyte").c_str(), ios::binary); } if (!fin1) throw error("Unable to open file train-images-idx3-ubyte or train-images.idx3-ubyte"); if (!fin2) throw error("Unable to open file train-labels-idx1-ubyte or train-labels.idx1-ubyte"); if (!fin3) throw error("Unable to open file t10k-images-idx3-ubyte or t10k-images.idx3-ubyte"); if (!fin4) throw error("Unable to open file t10k-labels-idx1-ubyte or t10k-labels.idx1-ubyte"); byte_orderer bo; // make sure the files have the contents we expect. 
uint32 magic, num, nr, nc, num2, num3, num4; fin1.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin1.read((char*)&num, sizeof(num)); bo.big_to_host(num); fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); if (magic != 2051 || num != 60000 || nr != 28 || nc != 28) throw error("mnist dat files are corrupted."); fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2); if (magic != 2049 || num2 != 60000) throw error("mnist dat files are corrupted."); fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3); fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr); fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc); if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28) throw error("mnist dat files are corrupted."); fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic); fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4); if (magic != 2049 || num4 != 10000) throw error("mnist dat files are corrupted."); if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte"); if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte"); training_images.resize(60000); training_labels.resize(60000); testing_images.resize(10000); testing_labels.resize(10000); for (size_t i = 0; i < training_images.size(); ++i) { training_images[i].set_size(nr,nc); fin1.read((char*)&training_images[i](0,0), nr*nc); } for (size_t i = 0; i < training_labels.size(); ++i) { char l; fin2.read(&l, 1); training_labels[i] = l; } for (size_t i = 0; i < testing_images.size(); ++i) { testing_images[i].set_size(nr,nc); fin3.read((char*)&testing_images[i](0,0), nr*nc); } for (size_t i = 0; i < testing_labels.size(); ++i) { char l; fin4.read(&l, 1); 
testing_labels[i] = l; } if (!fin1) throw error("Unable to read train-images-idx3-ubyte"); if (!fin2) throw error("Unable to read train-labels-idx1-ubyte"); if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte"); if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte"); if (fin1.get() != EOF) throw error("Unexpected bytes at end of train-images-idx3-ubyte"); if (fin2.get() != EOF) throw error("Unexpected bytes at end of train-labels-idx1-ubyte"); if (fin3.get() != EOF) throw error("Unexpected bytes at end of t10k-images-idx3-ubyte"); if (fin4.get() != EOF) throw error("Unexpected bytes at end of t10k-labels-idx1-ubyte"); } } // ---------------------------------------------------------------------------------------- #endif // DLIB_MNIST_CPp_ ================================================ FILE: dlib/data_io/mnist.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_MNIST_Hh_ #define DLIB_MNIST_Hh_ #include "mnist_abstract.h" #include #include #include "../matrix.h" // ---------------------------------------------------------------------------------------- namespace dlib { void load_mnist_dataset ( const std::string& folder_name, std::vector >& training_images, std::vector& training_labels, std::vector >& testing_images, std::vector& testing_labels ); } // ---------------------------------------------------------------------------------------- #ifdef NO_MAKEFILE #include "mnist.cpp" #endif #endif // DLIB_MNIST_Hh_ ================================================ FILE: dlib/data_io/mnist_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#undef DLIB_MNIST_ABSTRACT_Hh_ #ifdef DLIB_MNIST_ABSTRACT_Hh_ #include #include #include "../matrix.h" // ---------------------------------------------------------------------------------------- namespace dlib { void load_mnist_dataset ( const std::string& folder_name, std::vector >& training_images, std::vector& training_labels, std::vector >& testing_images, std::vector& testing_labels ); /*! ensures - Attempts to load the MNIST dataset from the hard drive. This is the dataset of handwritten digits available from http://yann.lecun.com/exdb/mnist/. In particular, the 4 files comprising the MNIST dataset should be present in the folder indicated by folder_name. These four files are: - train-images-idx3-ubyte - train-labels-idx1-ubyte - t10k-images-idx3-ubyte - t10k-labels-idx1-ubyte - #training_images == The 60,000 training images from the dataset. - #training_labels == The labels for the contents of #training_images. I.e. #training_labels[i] is the label of #training_images[i]. - #testing_images == The 10,000 testing images from the dataset. - #testing_labels == The labels for the contents of #testing_images. I.e. #testing_labels[i] is the label of #testing_images[i]. throws - dlib::error if some problem prevents us from loading the data or the files can't be found. !*/ } // ---------------------------------------------------------------------------------------- #endif // DLIB_MNIST_ABSTRACT_Hh_ ================================================ FILE: dlib/data_io.h ================================================ // Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DATA_Io_HEADER #define DLIB_DATA_Io_HEADER #include "data_io/libsvm_io.h" #include "data_io/image_dataset_metadata.h" #include "data_io/mnist.h" #include "data_io/cifar.h" #include "data_io/arc_agi.h" #ifndef DLIB_ISO_CPP_ONLY #include "data_io/load_image_dataset.h" #endif #endif // DLIB_DATA_Io_HEADER ================================================ FILE: dlib/dir_nav/dir_nav_extensions.cpp ================================================ // Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIR_NAV_EXTENSIONs_CPP_ #define DLIB_DIR_NAV_EXTENSIONs_CPP_ #include "dir_nav_extensions.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace implementation_details { void get_all_sub_dirs ( const directory& top_of_tree, unsigned long max_depth, std::vector& result, std::vector& temp ) { if (max_depth > 0) { top_of_tree.get_dirs(temp); const unsigned long start = result.size(); result.insert(result.end(), temp.begin(), temp.end()); const unsigned long end = start + temp.size(); for (unsigned long i = start; i < end; ++i) { get_all_sub_dirs(result[i], max_depth-1, result, temp); } } } } // ---------------------------------------------------------------------------------------- bool file_exists ( const std::string& filename ) { try { dlib::file temp(filename); return true; } catch (file::file_not_found&) { return false; } } // ---------------------------------------------------------------------------------------- bool directory_exists ( const std::string& dirname ) { try { dlib::directory temp(dirname); return true; } catch (directory::dir_not_found&) { return false; } } // ---------------------------------------------------------------------------------------- directory get_parent_directory ( const directory& dir ) { return dir.get_parent(); } // 
---------------------------------------------------------------------------------------- directory get_parent_directory ( const file& f ) { if (f.full_name().size() == 0) return directory(); std::string::size_type pos = f.full_name().find_last_of("\\/"); if (pos == std::string::npos) return directory(); return directory(f.full_name().substr(0,pos)); } // ---------------------------------------------------------------------------------------- std::string select_oldest_file ( const std::string& filename1, const std::string& filename2 ) { file f1, f2; try{f1 = file(filename1);} catch(file::file_not_found&) { return filename1; } try{f2 = file(filename2);} catch(file::file_not_found&) { return filename2; } if (f1.last_modified() < f2.last_modified()) return filename1; else return filename2; } // ---------------------------------------------------------------------------------------- std::string select_newest_file ( const std::string& filename1, const std::string& filename2 ) { file f1, f2; try{f1 = file(filename1);} catch(file::file_not_found&) { return filename2; } try{f2 = file(filename2);} catch(file::file_not_found&) { return filename1; } if (f1.last_modified() > f2.last_modified()) return filename1; else return filename2; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_DIR_NAV_EXTENSIONs_CPP_ ================================================ FILE: dlib/dir_nav/dir_nav_extensions.h ================================================ // Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DIR_NAV_EXTENSIONs_H_ #define DLIB_DIR_NAV_EXTENSIONs_H_ #include #include #include #include "dir_nav_extensions_abstract.h" #include "../dir_nav.h" #include "../string.h" namespace dlib { // ---------------------------------------------------------------------------------------- bool file_exists ( const std::string& filename ); // ---------------------------------------------------------------------------------------- bool directory_exists ( const std::string& dirname ); // ---------------------------------------------------------------------------------------- namespace implementation_details { void get_all_sub_dirs ( const directory& top_of_tree, unsigned long max_depth, std::vector& result, std::vector& temp ); } // ---------------------------------------------------------------------------------------- template const std::vector get_files_in_directory_tree ( const directory& top_of_tree, const T& add_file, unsigned long max_depth = 30 ) { std::vector result, temp; std::vector dirs, dirs_temp; dirs.push_back(top_of_tree); // get all the directories in the tree first implementation_details::get_all_sub_dirs(top_of_tree, max_depth, dirs, dirs_temp); // now just loop over all the directories and pick out the files we want to keep for (unsigned long d = 0; d < dirs.size(); ++d) { dirs[d].get_files(temp); // pick out the members of temp that we should keep for (unsigned long i = 0; i < temp.size(); ++i) { if (add_file(temp[i])) result.push_back(temp[i]); } } return result; } // ---------------------------------------------------------------------------------------- class match_ending { public: match_ending ( const std::string& ending_ ) : ending(ending_) {} bool operator() ( const file& f ) const { // if the ending is bigger than f's name then it obviously doesn't match if (ending.size() > f.name().size()) return false; // now check if the actual characters that make up the end of the file name // matches what is in ending. 
return std::equal(ending.begin(), ending.end(), f.name().end()-ending.size()); } private: std::string ending; }; // ---------------------------------------------------------------------------------------- class match_endings { public: match_endings ( const std::string& endings_ ) { const std::vector& s = split(endings_); for (unsigned long i = 0; i < s.size(); ++i) { endings.push_back(match_ending(s[i])); } } bool operator() ( const file& f ) const { for (unsigned long i = 0; i < endings.size(); ++i) { if (endings[i](f)) return true; } return false; } private: std::vector endings; }; // ---------------------------------------------------------------------------------------- class match_all { public: bool operator() ( const file& ) const { return true; } }; // ---------------------------------------------------------------------------------------- directory get_parent_directory ( const directory& dir ); // ---------------------------------------------------------------------------------------- directory get_parent_directory ( const file& f ); // ---------------------------------------------------------------------------------------- std::string select_oldest_file ( const std::string& filename1, const std::string& filename2 ); // ---------------------------------------------------------------------------------------- std::string select_newest_file ( const std::string& filename1, const std::string& filename2 ); // ---------------------------------------------------------------------------------------- } #ifdef NO_MAKEFILE #include "dir_nav_extensions.cpp" #endif #endif // DLIB_DIR_NAV_EXTENSIONs_H_ ================================================ FILE: dlib/dir_nav/dir_nav_extensions_abstract.h ================================================ // Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#undef DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_ #ifdef DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_ #include #include #include "dir_nav_kernel_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- bool file_exists ( const std::string& filename ); /*! ensures - if (a file with the given filename exists) then - returns true - else - returns false !*/ // ---------------------------------------------------------------------------------------- bool directory_exists ( const std::string& dirname ); /*! ensures - if (a directory with the given dirname exists) then - returns true - else - returns false !*/ // ---------------------------------------------------------------------------------------- template const std::vector get_files_in_directory_tree ( const directory& top_of_tree, const T& add_file, unsigned long max_depth = 30 ); /*! requires - add_file must be a function object with the following prototype: bool add_file (file f); ensures - performs a recursive search through the directory top_of_tree and all its sub-directories (up to the given max depth). All files in these directories are examined by passing them to add_file() and if it returns true then they will be included in the returned std::vector object. - Note that a max_depth of 0 means that only the files in the directory top_of_tree will be considered. A depth of 1 means that only files in top_of_tree and its immediate sub-directories will be considered. And so on... !*/ // ---------------------------------------------------------------------------------------- class match_ending { /*! WHAT THIS OBJECT REPRESENTS This is a simple function object that can be used with the above get_files_in_directory_tree() function. This object just looks for files with a certain ending. !*/ public: match_ending ( const std::string& ending ); /*! ensures - this object will be a function that checks if a file has a name that ends with the given ending string. 
!*/ bool operator() ( const file& f ) const; /*! ensures - if (the file f has a name that ends with the ending string given to this object's constructor) then - returns true - else - returns false !*/ }; // ---------------------------------------------------------------------------------------- class match_endings { /*! WHAT THIS OBJECT REPRESENTS This is a simple function object that can be used with the above get_files_in_directory_tree() function. This object allows you to look for files with a number of different endings. !*/ public: match_endings ( const std::string& ending_list ); /*! ensures - ending_list is interpreted as a whitespace separated list of file endings. - this object will be a function that checks if a file has a name that ends with one of the strings in ending_list. !*/ bool operator() ( const file& f ) const; /*! ensures - if (the file f has a name that ends with one of the ending strings given to this object's constructor) then - returns true - else - returns false !*/ }; // ---------------------------------------------------------------------------------------- class match_all { /*! WHAT THIS OBJECT REPRESENTS This is a simple function object that can be used with the above get_files_in_directory_tree() function. This object matches all files. !*/ public: bool operator() ( const file& f ) const; /*! ensures - returns true (i.e. this function doesn't do anything. It just says it matches all files no matter what) !*/ }; // ---------------------------------------------------------------------------------------- directory get_parent_directory ( const directory& dir ); /*! ensures - returns the parent directory of dir. In particular, this function returns the value of dir.get_parent() !*/ // ---------------------------------------------------------------------------------------- directory get_parent_directory ( const file& f ); /*! 
ensures - if (f.full_name() != "") then - returns the directory which contains the given file - else - returns a default initialized directory (i.e. directory()) !*/ // ---------------------------------------------------------------------------------------- std::string select_oldest_file ( const std::string& filename1, const std::string& filename2 ); /*! ensures - Checks the last modification times of the two given files and returns the filename of the oldest file, i.e., the file that has gone longest since being modified. Ties are broken arbitrarily. - For the purpose of comparison, a file that doesn't exist is presumed to have a last modification time of -infinity (i.e. very far in the past). !*/ // ---------------------------------------------------------------------------------------- std::string select_newest_file ( const std::string& filename1, const std::string& filename2 ); /*! ensures - Checks the last modification times of the two given files and returns the filename that was most recently modified. Ties are broken arbitrarily. - For the purpose of comparison, a file that doesn't exist is presumed to have a last modification time of -infinity (i.e. very far in the past). !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_DIR_NAV_EXTENSIONs_ABSTRACT_ ================================================ FILE: dlib/dir_nav/dir_nav_kernel_1.cpp ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIR_NAV_KERNEL_1_CPp_ #define DLIB_DIR_NAV_KERNEL_1_CPp_ #include "../platform.h" #ifdef WIN32 #include "dir_nav_kernel_1.h" #include "../string.h" #ifdef __BORLANDC__ // Apparently the borland compiler doesn't define this. 
#define INVALID_FILE_ATTRIBUTES ((DWORD)-1) #endif namespace dlib { // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // file object implementation // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void file:: init ( const std::string& name ) { using namespace std; char buf[3000]; char* str; if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0) { // the file was not found throw file_not_found("Unable to find file " + name); } state.full_name = buf; string::size_type pos = state.full_name.find_last_of(directory::get_separator()); if (pos == string::npos) { // no valid full path has no separator characters. throw file_not_found("Unable to find file " + name); } state.name = state.full_name.substr(pos+1); // now find the size of this file WIN32_FIND_DATAA data; HANDLE ffind = FindFirstFileA(state.full_name.c_str(), &data); if (ffind == INVALID_HANDLE_VALUE || (data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0) { throw file_not_found("Unable to find file " + name); } else { uint64 temp = data.nFileSizeHigh; temp <<= 32; temp |= data.nFileSizeLow; state.file_size = temp; FindClose(ffind); ULARGE_INTEGER ull; ull.LowPart = data.ftLastWriteTime.dwLowDateTime; ull.HighPart = data.ftLastWriteTime.dwHighDateTime; std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000)); state.last_modified = std::chrono::time_point(std::chrono::duration_cast(epoch)); } } // ---------------------------------------------------------------------------------------- bool file:: operator == ( const file& rhs ) const { using namespace std; if (state.full_name.size() != rhs.state.full_name.size()) return false; // compare the strings but ignore the case because file names // are not case sensitive on windows 
return tolower(state.full_name) == tolower(rhs.state.full_name); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // directory object implementation // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void directory:: init ( const std::string& name ) { using namespace std; char buf[3000]; char* str; if (GetFullPathNameA(name.c_str(),sizeof(buf),buf,&str) == 0) { // the directory was not found throw dir_not_found("Unable to find directory " + name); } state.full_name = buf; const char sep = get_separator(); if (is_root_path(state.full_name) == false) { // ensure that thre is not a trialing separator if (state.full_name[state.full_name.size()-1] == sep) state.full_name.erase(state.full_name.size()-1); // pick out the directory name string::size_type pos = state.full_name.find_last_of(sep); state.name = state.full_name.substr(pos+1); } else { // ensure that there is a trailing separator if (state.full_name[state.full_name.size()-1] != sep) state.full_name += sep; } // now check that this is actually a valid directory DWORD attribs = GetFileAttributesA(state.full_name.c_str()); if (attribs == INVALID_FILE_ATTRIBUTES || (attribs&FILE_ATTRIBUTE_DIRECTORY) == 0) { // the directory was not found throw dir_not_found("Unable to find directory " + name); } } // ---------------------------------------------------------------------------------------- char directory:: get_separator ( ) { return '\\'; } // ---------------------------------------------------------------------------------------- bool directory:: operator == ( const directory& rhs ) const { using namespace std; if (state.full_name.size() != rhs.state.full_name.size()) return false; // compare the strings but ignore the case because file names // are 
not case sensitive on windows return tolower(state.full_name) == tolower(rhs.state.full_name); } // ---------------------------------------------------------------------------------------- const directory directory:: get_parent ( ) const { using namespace std; // if *this is the root then just return *this if (is_root()) { return *this; } else { directory temp; const char sep = get_separator(); string::size_type pos = state.full_name.find_last_of(sep); temp.state.full_name = state.full_name.substr(0,pos); if ( is_root_path(temp.state.full_name)) { temp.state.full_name += sep; } else { pos = temp.state.full_name.find_last_of(sep); if (pos != string::npos) { temp.state.name = temp.state.full_name.substr(pos+1); } else { temp.state.full_name += sep; } } return temp; } } // ---------------------------------------------------------------------------------------- bool directory:: is_root_path ( const std::string& path ) const { using namespace std; const char sep = get_separator(); bool root_path = false; if (path.size() > 2 && path[0] == sep && path[1] == sep) { // in this case this is a windows share path string::size_type pos = path.find_first_of(sep,2); if (pos != string::npos) { pos = path.find_first_of(sep,pos+1); if (pos == string::npos && path[path.size()-1] != sep) root_path = true; else if (pos == path.size()-1) root_path = true; } } else if ( (path.size() == 2 || path.size() == 3) && path[1] == ':') { // if this is a valid windows path then it must be a root path root_path = true; } return root_path; } // ---------------------------------------------------------------------------------------- } #endif // WIN32 #endif // DLIB_DIR_NAV_KERNEL_1_CPp_ ================================================ FILE: dlib/dir_nav/dir_nav_kernel_1.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DIR_NAV_KERNEl_1_ #define DLIB_DIR_NAV_KERNEl_1_ #ifdef DLIB_ISO_CPP_ONLY #error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." #endif #include "../platform.h" #include "dir_nav_kernel_abstract.h" #include #include "../uintn.h" #include "../algs.h" #include "../windows_magic.h" #include #include #include "../stl_checked.h" #include "../enable_if.h" #include "../queue.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // file object // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class file { /*! INITIAL VALUES state.name == name() state.full_name == full_name() state.file_size == size() state.last_modified == last_modified() CONVENTION state.name == name() state.full_name == full_name() state.file_size == size() state.last_modified == last_modified() !*/ friend class directory; struct data { uint64 file_size; std::string name; std::string full_name; std::chrono::time_point last_modified; }; void init ( const std::string& name); public: struct private_constructor{}; inline file ( const std::string& name, const std::string& full_name, const uint64 file_size, const std::chrono::time_point& last_modified, private_constructor ) { state.file_size = file_size; state.name = name; state.full_name = full_name; state.last_modified = last_modified; } class file_not_found : public error { public: file_not_found(const std::string& s): error(s){} }; inline file ( ) { state.file_size = 0; } file ( const std::string& name ) { init(name); } file ( const char* name ) { init(name); } inline const std::string& name ( ) const { return state.name; } inline const std::string& full_name ( ) const { return 
state.full_name; } operator std::string ( ) const { return full_name(); } inline uint64 size ( ) const { return state.file_size; } inline std::chrono::time_point last_modified ( ) const { return state.last_modified; } bool operator == ( const file& rhs ) const; bool operator != ( const file& rhs ) const { return !(*this == rhs); } inline bool operator < ( const file& item ) const { return full_name() < item.full_name(); } inline void swap ( file& item ) { exchange(state,item.state); } private: data state; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // directory object // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class directory { /*! INITIAL VALUES state.name == name() state.full_name == full_name() CONVENTION state.name == name() state.full_name == full_name() is_root() == state.name.size() == 0 !*/ void init (const std::string& name); public: struct data { std::string name; std::string full_name; }; /* The reason we don't just make this constructor actually private is because doing it this way avoids a bug that sometimes occurs in visual studio 7.1. The bug has something to do with templated friend functions such as the get_filesystem_roots() function below if it was declared as a friend template of this class. 
*/ struct private_constructor{}; inline directory ( const std::string& name, const std::string& full_name, private_constructor ) { state.name = name; state.full_name = full_name; } class dir_not_found : public error { public: dir_not_found(const std::string& s):error(s){} }; class listing_error : public error { public: listing_error(const std::string& s):error(s){} }; inline directory ( ) { } directory ( const std::string& name ) { init(name); } directory ( const char* name ) { init(name); } static char get_separator ( ); template < typename queue_of_files > void get_files ( queue_of_files& files ) const; template < typename queue_of_dirs > void get_dirs ( queue_of_dirs& dirs ) const; std::vector get_files ( ) const { std::vector temp_vector; get_files(temp_vector); return temp_vector; } std::vector get_dirs ( ) const { std::vector temp_vector; get_dirs(temp_vector); return temp_vector; } const directory get_parent ( ) const; inline bool is_root ( ) const { return state.name.size() == 0; } inline const std::string& name ( ) const { return state.name; } inline const std::string& full_name ( ) const { return state.full_name; } operator std::string ( ) const { return full_name(); } bool operator == ( const directory& rhs ) const; bool operator != ( const directory& rhs ) const { return !(*this == rhs); } inline bool operator < ( const directory& item ) const { return full_name() < item.full_name(); } inline void swap ( directory& item ) { exchange(state,item.state); } private: // member data data state; bool is_root_path ( const std::string& path ) const; /*! ensures - returns true if path is a root path. Note that this function considers root paths that don't have a trailing separator to also be valid. 
!*/ }; // ---------------------------------------------------------------------------------------- inline std::ostream& operator<< ( std::ostream& out, const directory& item ) { out << (std::string)item; return out; } inline std::ostream& operator<< ( std::ostream& out, const file& item ) { out << (std::string)item; return out; } // ---------------------------------------------------------------------------------------- template < typename queue_of_dir > typename disable_if,void>::type get_filesystem_roots ( queue_of_dir& roots ) { roots.clear(); const DWORD mask = GetLogicalDrives(); DWORD bit = 1; char buf[] = "A:\\"; do { if (mask & bit) { directory dir("",buf,directory::private_constructor()); roots.enqueue(dir); } bit <<= 1; ++buf[0]; } while (buf[0] != 'Z'); } template < typename queue_of_dir > typename enable_if,void>::type get_filesystem_roots ( queue_of_dir& roots ) { roots.clear(); const DWORD mask = GetLogicalDrives(); DWORD bit = 1; char buf[] = "A:\\"; do { if (mask & bit) { directory dir("",buf,directory::private_constructor()); roots.push_back(dir); } bit <<= 1; ++buf[0]; } while (buf[0] != 'Z'); } // ---------------------------------------------------------------------------------------- inline void swap ( file& a, file& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- inline void swap ( directory& a, directory& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // templated member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename queue_of_files > typename disable_if,void>::type directory_helper_get_files ( const directory::data& state, queue_of_files& files ) { 
typedef directory::listing_error listing_error; typedef file::private_constructor private_constructor; files.clear(); if (state.full_name.size() == 0) throw listing_error("This directory object currently doesn't represent any directory."); HANDLE ffind = INVALID_HANDLE_VALUE; try { WIN32_FIND_DATAA data; std::string path = state.full_name; // ensure that the path ends with a separator if (path[path.size()-1] != directory::get_separator()) path += directory::get_separator(); ffind = FindFirstFileA((path+"*").c_str(), &data); if (ffind == INVALID_HANDLE_VALUE) { throw listing_error("Unable to list the contents of " + state.full_name); } bool no_more_files = false; do { if ((data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) == 0) { uint64 file_size = data.nFileSizeHigh; file_size <<= 32; file_size |= data.nFileSizeLow; ULARGE_INTEGER ull; ull.LowPart = data.ftLastWriteTime.dwLowDateTime; ull.HighPart = data.ftLastWriteTime.dwHighDateTime; std::chrono::nanoseconds epoch(100 * (ull.QuadPart - 116444736000000000)); auto last_modified = std::chrono::time_point(std::chrono::duration_cast(epoch)); // this is a file so add it to the queue file temp(data.cFileName,path+data.cFileName,file_size, last_modified, private_constructor()); files.enqueue(temp); } if (FindNextFileA(ffind,&data) == 0) { // an error occurred if ( GetLastError() == ERROR_NO_MORE_FILES) { // there are no more files no_more_files = true; } else { // there was an error throw listing_error("Unable to list the contents of " + state.full_name); } } } while (no_more_files == false); FindClose(ffind); ffind = INVALID_HANDLE_VALUE; } catch (...) 
{ if (ffind != INVALID_HANDLE_VALUE) FindClose(ffind); files.clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename queue_of_files > typename enable_if,void>::type directory_helper_get_files ( const directory::data& state, queue_of_files& files ) { queue::kernel_2a temp_files; directory_helper_get_files(state,temp_files); files.clear(); // copy the queue of files into the vector temp_files.reset(); while (temp_files.move_next()) { files.push_back(temp_files.element()); } } // ---------------------------------------------------------------------------------------- template < typename queue_of_files > void directory:: get_files ( queue_of_files& files ) const { // the reason for this indirection here is because it avoids a bug in // the mingw version of gcc directory_helper_get_files(state,files); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename queue_of_dirs > typename disable_if,void>::type directory_helper_get_dirs ( const directory::data& state, queue_of_dirs& dirs ) { typedef directory::listing_error listing_error; typedef directory::private_constructor private_constructor; dirs.clear(); if (state.full_name.size() == 0) throw listing_error("This directory object currently doesn't represent any directory."); HANDLE dfind = INVALID_HANDLE_VALUE; try { WIN32_FIND_DATAA data; std::string path = state.full_name; // ensure that the path ends with a separator if (path[path.size()-1] != directory::get_separator()) path += directory::get_separator(); dfind = FindFirstFileA((path+"*").c_str(), &data); if (dfind == INVALID_HANDLE_VALUE) { throw listing_error("Unable to list the contents of " + state.full_name); } bool no_more_files = false; do { 
std::string tname(data.cFileName); if ((data.dwFileAttributes&FILE_ATTRIBUTE_DIRECTORY) != 0 && tname != "." && tname != "..") { // this is a directory so add it to the queue directory temp(tname,path+tname,private_constructor()); dirs.enqueue(temp); } if (FindNextFileA(dfind,&data) == 0) { // an error occurred if ( GetLastError() == ERROR_NO_MORE_FILES) { // there are no more files no_more_files = true; } else { // there was an error throw listing_error("Unable to list the contents of " + state.full_name); } } } while (no_more_files == false); FindClose(dfind); dfind = INVALID_HANDLE_VALUE; } catch (...) { if (dfind != INVALID_HANDLE_VALUE) FindClose(dfind); dirs.clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename queue_of_dirs > typename enable_if,void>::type directory_helper_get_dirs ( const directory::data& state, queue_of_dirs& dirs ) { queue::kernel_2a temp_dirs; directory_helper_get_dirs(state,temp_dirs); dirs.clear(); // copy the queue of dirs into the vector temp_dirs.reset(); while (temp_dirs.move_next()) { dirs.push_back(temp_dirs.element()); } } // ---------------------------------------------------------------------------------------- template < typename queue_of_dirs > void directory:: get_dirs ( queue_of_dirs& dirs ) const { // the reason for this indirection here is because it avoids a bug in // the mingw version of gcc directory_helper_get_dirs(state,dirs); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- } #ifdef NO_MAKEFILE #include "dir_nav_kernel_1.cpp" #endif #endif // DLIB_DIR_NAV_KERNEl_1_ ================================================ FILE: dlib/dir_nav/dir_nav_kernel_2.cpp ================================================ // Copyright 
(C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIR_NAV_KERNEL_2_CPp_ #define DLIB_DIR_NAV_KERNEL_2_CPp_ #include "../platform.h" #ifdef DLIB_POSIX #include "dir_nav_kernel_2.h" namespace dlib { // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // file object implementation // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void file:: init ( const std::string& name ) { using namespace std; char buf[PATH_MAX]; if (realpath(name.c_str(),buf) == 0) { // the file was not found throw file_not_found("Unable to find file " + name); } state.full_name = buf; string::size_type pos = state.full_name.find_last_of(directory::get_separator()); if (pos == string::npos) { // no valid full path has no separtor characters. 
throw file_not_found("Unable to find file " + name); } state.name = state.full_name.substr(pos+1); // now find the size of this file struct stat64 buffer; if (::stat64(state.full_name.c_str(), &buffer) || S_ISDIR(buffer.st_mode)) { // there was an error during the call to stat64 or // name is actually a directory throw file_not_found("Unable to find file " + name); } else { state.file_size = static_cast(buffer.st_size); state.last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime); #ifdef _BSD_SOURCE state.last_modified += std::chrono::duration_cast(std::chrono::nanoseconds(buffer.st_atim.tv_nsec)); #endif } } // ---------------------------------------------------------------------------------------- bool file:: operator == ( const file& rhs ) const { using namespace std; if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0) return true; // These files might have different names but actually represent the same // file due to the presence of symbolic links. char buf[PATH_MAX]; string left, right; if (realpath(state.full_name.c_str(),buf) == 0) return false; left = buf; if (realpath(rhs.state.full_name.c_str(),buf) == 0) return false; right = buf; return (left == right); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // directory object implementation // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void directory:: init ( const std::string& name ) { using namespace std; char buf[PATH_MAX]; if (realpath(name.c_str(),buf) == 0) { // the directory was not found throw dir_not_found("Unable to find directory " + name); } state.full_name = buf; const char sep = get_separator(); if (is_root_path(state.full_name) == false) { // ensure that thre is not a trialing separator if 
(state.full_name[state.full_name.size()-1] == sep) state.full_name.erase(state.full_name.size()-1); // pick out the directory name string::size_type pos = state.full_name.find_last_of(sep); state.name = state.full_name.substr(pos+1); } else { // ensure that there is a trailing separator if (state.full_name[state.full_name.size()-1] != sep) state.full_name += sep; } struct stat64 buffer; // now check that this is actually a valid directory if (::stat64(state.full_name.c_str(),&buffer)) { // the directory was not found throw dir_not_found("Unable to find directory " + name); } else if (S_ISDIR(buffer.st_mode) == 0) { // It is not a directory throw dir_not_found("Unable to find directory " + name); } } // ---------------------------------------------------------------------------------------- char directory:: get_separator ( ) { return '/'; } // ---------------------------------------------------------------------------------------- bool directory:: operator == ( const directory& rhs ) const { using namespace std; if (state.full_name.size() == 0 && rhs.state.full_name.size() == 0) return true; // These directories might have different names but actually represent the same // directory due to the presence of symbolic links. 
char buf[PATH_MAX]; string left, right; if (realpath(state.full_name.c_str(),buf) == 0) return false; left = buf; if (realpath(rhs.state.full_name.c_str(),buf) == 0) return false; right = buf; return (left == right); } // ---------------------------------------------------------------------------------------- const directory directory:: get_parent ( ) const { using namespace std; // if *this is the root then just return *this if (is_root()) { return *this; } else { directory temp; const char sep = get_separator(); string::size_type pos = state.full_name.find_last_of(sep); temp.state.full_name = state.full_name.substr(0,pos); if ( is_root_path(temp.state.full_name)) { temp.state.full_name += sep; } else { pos = temp.state.full_name.find_last_of(sep); if (pos != string::npos) { temp.state.name = temp.state.full_name.substr(pos+1); } else { temp.state.full_name += sep; } } return temp; } } // ---------------------------------------------------------------------------------------- bool directory:: is_root_path ( const std::string& path ) const { const char sep = get_separator(); if (path.size() == 1 && path[0] == sep) return true; else return false; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_POSIX #endif // DLIB_DIR_NAV_KERNEL_2_CPp_ ================================================ FILE: dlib/dir_nav/dir_nav_kernel_2.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIR_NAV_KERNEl_2_ #define DLIB_DIR_NAV_KERNEl_2_ #ifdef DLIB_ISO_CPP_ONLY #error "DLIB_ISO_CPP_ONLY is defined so you can't use this OS dependent code. Turn DLIB_ISO_CPP_ONLY off if you want to use it." 
#endif #include "dir_nav_kernel_abstract.h" #include #include "../uintn.h" #include "../algs.h" #include #include #include #include #include #include #include #include #include #if !defined(__USE_LARGEFILE64 ) && !defined(_LARGEFILE64_SOURCE) #define stat64 stat #endif #include #include "../stl_checked.h" #include "../enable_if.h" #include "../queue.h" namespace dlib { // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // file object // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class file { /*! INITIAL VALUES state.name == name() state.full_name == full_name() state.file_size == size() state.last_modified == last_modified() CONVENTION state.name == name() state.full_name == full_name() state.file_size == size() state.last_modified == last_modified() !*/ friend class directory; struct data { uint64 file_size; std::string name; std::string full_name; std::chrono::time_point last_modified; }; void init(const std::string& name); public: struct private_constructor{}; inline file ( const std::string& name, const std::string& full_name, const uint64 file_size, const std::chrono::time_point& last_modified, private_constructor ) { state.file_size = file_size; state.name = name; state.full_name = full_name; state.last_modified = last_modified; } class file_not_found : public error { public: file_not_found(const std::string& s): error(s){} }; inline file ( ) { state.file_size = 0; } file ( const std::string& name ) { init(name); } file ( const char* name ) { init(name); } inline const std::string& name ( ) const { return state.name; } inline const std::string& full_name ( ) const { return state.full_name; } inline uint64 size ( ) const { return state.file_size; } inline std::chrono::time_point last_modified ( 
) const { return state.last_modified; } operator std::string ( ) const { return full_name(); } bool operator == ( const file& rhs ) const; bool operator != ( const file& rhs ) const { return !(*this == rhs); } inline bool operator < ( const file& item ) const { return full_name() < item.full_name(); } inline void swap ( file& item ) { exchange(state,item.state); } private: // member data data state; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // directory object // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class directory { /*! INITIAL VALUES state.name == name() state.full_name == full_name() CONVENTION state.name == name() state.full_name == full_name() is_root() == state.name.size() == 0 !*/ void init(const std::string& name); public: struct private_constructor{}; inline directory ( const std::string& name, const std::string& full_name, private_constructor ) { state.name = name; state.full_name = full_name; } struct data { std::string name; std::string full_name; }; class dir_not_found : public error { public: dir_not_found(const std::string& s):error(s){} }; class listing_error : public error { public: listing_error(const std::string& s):error(s){} }; inline directory ( ) { } directory ( const std::string& name ) { init(name); } directory ( const char* name ) { init(name); } static char get_separator ( ); template < typename queue_of_files > void get_files ( queue_of_files& files ) const; template < typename queue_of_dirs > void get_dirs ( queue_of_dirs& dirs ) const; std::vector get_files ( ) const { std::vector temp_vector; get_files(temp_vector); return temp_vector; } std::vector get_dirs ( ) const { std::vector temp_vector; get_dirs(temp_vector); return temp_vector; } const 
directory get_parent ( ) const; inline bool is_root ( ) const { return state.name.size() == 0; } inline const std::string& name ( ) const { return state.name; } inline const std::string& full_name ( ) const { return state.full_name; } operator std::string ( ) const { return full_name(); } bool operator == ( const directory& rhs ) const; bool operator != ( const directory& rhs ) const { return !(*this == rhs); } inline bool operator < ( const directory& item ) const { return full_name() < item.full_name(); } inline void swap ( directory& item ) { exchange(state,item.state); } private: // member data data state; bool is_root_path ( const std::string& path ) const; /*! ensures - returns true if path is a root path. Note that this function considers root paths that don't have a trailing separator to also be valid. !*/ }; // ---------------------------------------------------------------------------------------- inline std::ostream& operator<< ( std::ostream& out, const directory& item ) { out << (std::string)item; return out; } inline std::ostream& operator<< ( std::ostream& out, const file& item ) { out << (std::string)item; return out; } // ---------------------------------------------------------------------------------------- inline void swap ( file& a, file& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- inline void swap ( directory& a, directory& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // templated member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename queue_of_files > typename disable_if,void>::type directory_helper_get_files ( const directory::data& state, 
queue_of_files& files ) { files.clear(); if (state.full_name.size() == 0) throw directory::listing_error("This directory object currently doesn't represent any directory."); DIR* ffind = 0; struct dirent* data; struct stat64 buffer; try { std::string path = state.full_name; // ensure that the path ends with a separator if (path[path.size()-1] != directory::get_separator()) path += directory::get_separator(); // get a handle to something we can search with ffind = opendir(state.full_name.c_str()); if (ffind == 0) { throw directory::listing_error("Unable to list the contents of " + state.full_name); } while(true) { errno = 0; if ( (data = readdir(ffind)) == 0) { // there was an error or no more files if ( errno == 0) { // there are no more files break; } else { // there was an error throw directory::listing_error("Unable to list the contents of " + state.full_name); } } uint64 file_size; // get a stat64 structure so we can see if this is a file if (::stat64((path+data->d_name).c_str(), &buffer) != 0) { // this might be a broken symbolic link. We can check by calling // readlink and seeing if it finds anything. char buf[PATH_MAX]; ssize_t temp = readlink((path+data->d_name).c_str(),buf,sizeof(buf)); if (temp == -1) throw directory::listing_error("Unable to list the contents of " + state.full_name); else file_size = static_cast(temp); } else { file_size = static_cast(buffer.st_size); } auto last_modified = std::chrono::system_clock::from_time_t(buffer.st_mtime); #ifdef _BSD_SOURCE last_modified += std::chrono::duration_cast(std::chrono::nanoseconds(buffer.st_atim.tv_nsec)); #endif if (S_ISDIR(buffer.st_mode) == 0) { // this is actually a file file temp( data->d_name, path+data->d_name, file_size, last_modified, file::private_constructor() ); files.enqueue(temp); } } // while (true) if (ffind != 0) { while (closedir(ffind)) { if (errno != EINTR) break; } ffind = 0; } } catch (...) 
{ if (ffind != 0) { while (closedir(ffind)) { if (errno != EINTR) break; } ffind = 0; } files.clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename queue_of_files > typename enable_if,void>::type directory_helper_get_files ( const directory::data& state, queue_of_files& files ) { queue::kernel_2a temp_files; directory_helper_get_files(state,temp_files); files.clear(); // copy the queue of files into the vector temp_files.reset(); while (temp_files.move_next()) { files.push_back(temp_files.element()); } } // ---------------------------------------------------------------------------------------- template < typename queue_of_files > void directory:: get_files ( queue_of_files& files ) const { // the reason for this indirection here is because it avoids a bug in // the cygwin version of gcc directory_helper_get_files(state,files); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename queue_of_dirs > typename disable_if,void>::type directory_helper_get_dirs ( const directory::data& state, queue_of_dirs& dirs ) { dirs.clear(); if (state.full_name.size() == 0) throw directory::listing_error("This directory object currently doesn't represent any directory."); DIR* ffind = 0; struct dirent* data; struct stat64 buffer; try { std::string path = state.full_name; // ensure that the path ends with a separator if (path[path.size()-1] != directory::get_separator()) path += directory::get_separator(); // get a handle to something we can search with ffind = opendir(state.full_name.c_str()); if (ffind == 0) { throw directory::listing_error("Unable to list the contents of " + state.full_name); } while(true) { errno = 0; if ( (data = readdir(ffind)) == 0) { // there 
was an error or no more files if ( errno == 0) { // there are no more files break; } else { // there was an error throw directory::listing_error("Unable to list the contents of " + state.full_name); } } // get a stat64 structure so we can see if this is a file if (::stat64((path+data->d_name).c_str(), &buffer) != 0) { // just assume this isn't a directory. It is probably a broken // symbolic link. continue; } std::string dtemp(data->d_name); if (S_ISDIR(buffer.st_mode) && dtemp != "." && dtemp != ".." ) { // this is a directory so add it to dirs directory temp(dtemp,path+dtemp, directory::private_constructor()); dirs.enqueue(temp); } } // while (true) if (ffind != 0) { while (closedir(ffind)) { if (errno != EINTR) break; } ffind = 0; } } catch (...) { if (ffind != 0) { while (closedir(ffind)) { if (errno != EINTR) break; } ffind = 0; } dirs.clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename queue_of_dirs > typename enable_if,void>::type directory_helper_get_dirs ( const directory::data& state, queue_of_dirs& dirs ) { queue::kernel_2a temp_dirs; directory_helper_get_dirs(state,temp_dirs); dirs.clear(); // copy the queue of dirs into the vector temp_dirs.reset(); while (temp_dirs.move_next()) { dirs.push_back(temp_dirs.element()); } } // ---------------------------------------------------------------------------------------- template < typename queue_of_dirs > void directory:: get_dirs ( queue_of_dirs& dirs ) const { // the reason for this indirection here is because it avoids a bug in // the cygwin version of gcc directory_helper_get_dirs(state,dirs); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename queue_of_dir > typename 
disable_if,void>::type get_filesystem_roots ( queue_of_dir& roots ) { roots.clear(); directory dir("/"); roots.enqueue(dir); } template < typename queue_of_dir > typename enable_if,void>::type get_filesystem_roots ( std::vector& roots ) { roots.clear(); directory dir("/"); roots.push_back(dir); } // ---------------------------------------------------------------------------------------- } #ifdef NO_MAKEFILE #include "dir_nav_kernel_2.cpp" #endif #endif // DLIB_DIR_NAV_KERNEl_2_ ================================================ FILE: dlib/dir_nav/dir_nav_kernel_abstract.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_DIR_NAV_KERNEl_ABSTRACT_ #ifdef DLIB_DIR_NAV_KERNEl_ABSTRACT_ #include #include #include "../uintn.h" #include "../algs.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- /*! GENERAL WARNING Don't call any of these functions or make any of these objects before main() has been entered. That means no instances of file or directory at the global scope. !*/ // ---------------------------------------------------------------------------------------- template < typename queue_of_dir > void get_filesystem_roots ( queue_of_dir& roots ); /*! requires - queue_of_dirs == an implementation of queue/queue_kernel_abstract.h with T set to directory or a std::vector or dlib::std_vector_c. ensures - #roots == a queue containing directories that represent all the roots of the filesystem on this machine. (e.g. in windows you have c:\, d:\ etc.) 
throws - std::bad_alloc !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // file object // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class file { /*! WHAT THIS OBJECT REPRESENTS This object represents a file. Note that the size of a file is determined at the time the file object is constructed. Thus if a file changes sizes after its file object has been created its file object's size() method will not reflect the new file size. !*/ public: class file_not_found : public error {}; file ( ); /*! ensures - #*this has been properly initialized - #name() == "" - #full_name() == "" - #size() == 0 - #*this does not represent any file throws - std::bad_alloc !*/ file ( const std::string& name ); /*! ensures - #*this has been properly initialized - #*this represents the file given by name Note that name can be a fully qualified path or just a path relative to the current working directory. Also, any symbolic links in name will be resolved. throws - std::bad_alloc - file_not_found This exception is thrown if the file can not be found or accessed. !*/ file ( const char* name ); /*! ensures - this function is identical to file(const std::string& name) !*/ file ( const file& item ); /*! ensures - #*this == item throws - std::bad_alloc !*/ ~file ( ); /*! ensures - all resources associated with *this have been released !*/ const std::string& name ( ) const; /*! ensures - returns the name of the file. This is full_name() minus the path to the file. !*/ const std::string& full_name ( ) const; /*! ensures - returns the fully qualified name for the file represented by *this !*/ uint64 size ( ) const; /*! ensures - returns the size of this file in bytes. !*/ std::chrono::time_point last_modified ( ) const; /*! 
ensures - returns the time the file was last modified. !*/ operator std::string ( ) const; /*! ensures - returns full_name() (i.e. provides an implicit conversion to string from dlib::file) !*/ file& operator= ( const file& rhs ); /*! ensures - #*this == rhs !*/ bool operator == ( const file& rhs ) const; /*! ensures - if (*this and rhs represent the same file) then - returns true - else - returns false !*/ bool operator != ( const file& rhs ) const; /*! ensures - if (*this and rhs represent the same file) then - returns false - else - returns true !*/ bool operator < ( const file& item ) const; /*! ensures - if (full_name() < item.full_name()) then - returns true - else - returns false !*/ void swap ( file& item ); /*! ensures - swaps *this and item !*/ }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // directory object // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class directory { /*! WHAT THIS OBJECT REPRESENTS This object represents a directory in a file system. It gives the ability to traverse a directory tree. Note that the directories . and .. are not returned by get_dirs() !*/ public: class dir_not_found : public error {}; class listing_error : public error {}; directory ( ); /*! ensures - #*this has been properly initialized - #full_name() == "" - #name() == "" - #is_root() == true - #*this does not represent any directory throws - std::bad_alloc !*/ directory ( const std::string& name ); /*! ensures - #*this has been properly initialized - #*this represents the directory given by name. Note that name can be a fully qualified path or just a path relative to the current working directory. Also, any symbolic links in name will be resolved. 
throws - std::bad_alloc - dir_not_found This exception is thrown if the directory can not be found or accessed. !*/ directory ( const char* name ); /*! ensures - this function is identical to directory(const std::string& name) !*/ directory ( const directory& item ); /*! ensures - #*this == item throws - std::bad_alloc !*/ ~directory ( ); /*! ensures - all resources associated with *this have been released !*/ static char get_separator ( ); /*! ensures - returns the character used to separate directories and file names in a path. (i.e. \ on windows and / in unix) !*/ template < typename queue_of_files > void get_files ( queue_of_files& files ) const; /*! requires - queue_of_files == an implementation of queue/queue_kernel_abstract.h with T set to file or a std::vector or dlib::std_vector_c. ensures - #files == A queue containing all the files present in this directory. (Note that symbolic links will not have been resolved in the names of the returned files.) - #files.size() == the number of files in this directory throws - bad_alloc If this exception is thrown then the call to get_files() has no effect on *this and #files is unusable until files.clear() is called and succeeds. - listing_error This exception is thrown if listing access has been denied to this directory or if some error occurred that prevented us from successfully getting the contents of this directory. If this exception is thrown then the call to get_files() has no effect on *this and #files.size()==0. !*/ std::vector get_files ( ) const; /*! ensures - This function simply calls get_files(temp_vector) and then returns temp_vector. !*/ template < typename queue_of_dirs > void get_dirs ( queue_of_dirs& dirs ) const; /*! requires - queue_of_dirs == an implementation of queue/queue_kernel_abstract.h with T set to directory or a std::vector or dlib::std_vector_c. ensures - #dirs == a queue containing all the directories present in this directory. 
(note that symbolic links will not have been resolved in the names of the returned directories.) - #dirs.size() == the number of subdirectories in this directory throws - bad_alloc If this exception is thrown then the call to get_files() has no effect on *this and #files is unusable until files.clear() is called and succeeds. - listing_error This exception is thrown if listing access has been denied to this directory or if some error occurred that prevented us from successfully getting the contents of this directory. If this exception is thrown then the call to get_dirs() has no effect on *this and #dirs.size()==0. !*/ std::vector get_dirs ( ) const; /*! ensures - This function simply calls get_dirs(temp_vector) and then returns temp_vector. !*/ bool is_root ( ) const; /*! ensures - if (*this represents the root of this directory tree) then - returns true - else - returns false !*/ const directory get_parent ( ) const; /*! ensures - if (is_root()) then - returns a copy of *this - else - returns the parent directory of *this throws - bad_alloc If this exception is thrown then the call to get_parent() will have no effect. !*/ const std::string& name ( ) const; /*! ensures - if (is_root()) then - returns "" - else - returns the name of the directory. This is full_name() minus the path to the directory. !*/ const std::string& full_name ( ) const; /*! ensures - returns the fully qualified directory name for *this - if (is_root()) then - the last character of #full_name() is get_separator() - else - the last character of #full_name() is NOT get_separator() !*/ operator std::string ( ) const; /*! ensures - returns full_name() (i.e. provides an implicit conversion to string from dlib::directory) !*/ directory& operator= ( const directory& rhs ); /*! ensures - #*this == rhs !*/ bool operator == ( const directory& rhs ) const; /*! 
ensures - if (*this and rhs represent the same directory) then - returns true - else - returns false !*/ bool operator != ( const directory& rhs ) const; /*! ensures - if (*this and rhs represent the same directory) then - returns false - else - returns true !*/ bool operator < ( const directory& item ) const; /*! ensures - if (full_name() < item.full_name()) then - returns true - else - returns false !*/ void swap ( directory& item ); /*! ensures - swaps *this and item !*/ }; // ---------------------------------------------------------------------------------------- inline std::ostream& operator<< ( std::ostream& out, const directory& item ); /*! ensures - performs: out << item.full_name() - returns out !*/ inline std::ostream& operator<< ( std::ostream& out, const file& item ); /*! ensures - performs: out << item.full_name() - returns out !*/ // ---------------------------------------------------------------------------------------- inline void swap ( file& a, file& b ) { a.swap(b); } /*! provides a global swap function for file objects !*/ // ---------------------------------------------------------------------------------------- inline void swap ( directory& a, directory& b ) { a.swap(b); } /*! provides a global swap function for directory objects !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_DIR_NAV_KERNEl_ABSTRACT_ ================================================ FILE: dlib/dir_nav/posix.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIR_NAV_KERNEl_1_ #include "dir_nav_kernel_2.h" #endif ================================================ FILE: dlib/dir_nav/windows.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DIR_NAV_KERNEl_2_ #include "dir_nav_kernel_1.h" #endif ================================================ FILE: dlib/dir_nav.h ================================================ // Copyright (C) 2003 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIR_NAv_ #define DLIB_DIR_NAv_ #include "platform.h" #ifdef WIN32 #include "dir_nav/windows.h" #endif #ifndef WIN32 #include "dir_nav/posix.h" #endif #include "dir_nav/dir_nav_extensions.h" #endif // DLIB_DIR_NAv_ ================================================ FILE: dlib/directed_graph/directed_graph_kernel_1.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIRECTED_GRAPH_KERNEl_1_ #define DLIB_DIRECTED_GRAPH_KERNEl_1_ #include #include #include "../serialize.h" #include "../noncopyable.h" #include "../std_allocator.h" #include "../algs.h" #include "directed_graph_kernel_abstract.h" #include "../is_kind.h" namespace dlib { // ---------------------------------------------------------------------------------------- template struct directed_graph_checker_helper { /*! 
This object is used to check preconditions based on the value of is_checked !*/ static void check_parent_edge ( unsigned long edge_index, const node_type& self ) { // make sure requires clause is not broken DLIB_CASSERT(edge_index < self.number_of_parents(), "\tnode_type& directed_graph::node_type::parent_edge(edge_index)" << "\n\tYou have specified an invalid index" << "\n\tedge_index: " << edge_index << "\n\tnumber_of_parents(): " << self.number_of_parents() << "\n\tthis: " << &self ); } static void check_child_edge ( unsigned long edge_index, const node_type& self ) { // make sure requires clause is not broken DLIB_CASSERT(edge_index < self.number_of_children(), "\tnode_type& directed_graph::node_type::child_edge(edge_index)" << "\n\tYou have specified an invalid index" << "\n\tedge_index: " << edge_index << "\n\tnumber_of_children(): " << self.number_of_children() << "\n\tthis: " << &self ); } static void check_parent ( unsigned long edge_index, const node_type& self ) { // make sure requires clause is not broken DLIB_CASSERT(edge_index < self.number_of_parents(), "\tnode_type& directed_graph::node_type::parent(edge_index)" << "\n\tYou have specified an invalid index" << "\n\tedge_index: " << edge_index << "\n\tnumber_of_parents(): " << self.number_of_parents() << "\n\tthis: " << &self ); } static void check_child ( unsigned long edge_index, const node_type& self ) { // make sure requires clause is not broken DLIB_CASSERT(edge_index < self.number_of_children(), "\tnode_type& directed_graph::node_type::child(edge_index)" << "\n\tYou have specified an invalid index" << "\n\tedge_index: " << edge_index << "\n\tnumber_of_children(): " << self.number_of_children() << "\n\tthis: " << &self ); } static void check_node ( unsigned long index, const directed_graph& self ) { // make sure requires clause is not broken DLIB_CASSERT(index < self.number_of_nodes(), "\tnode_type& directed_graph::node(index)" << "\n\tYou have specified an invalid index" << "\n\tindex: " << 
index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); } static void check_has_edge ( unsigned long parent_node_index, unsigned long child_node_index, const directed_graph& self ) { // make sure requires clause is not broken DLIB_CASSERT(parent_node_index < self.number_of_nodes() && child_node_index < self.number_of_nodes(), "\tvoid directed_graph::has_edge(parent_node_index, child_node_index)" << "\n\tYou have specified an invalid index" << "\n\tparent_node_index: " << parent_node_index << "\n\tchild_node_index: " << child_node_index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); } static void check_add_edge ( unsigned long parent_node_index, unsigned long child_node_index, const directed_graph& self ) { DLIB_CASSERT(parent_node_index < self.number_of_nodes() && child_node_index < self.number_of_nodes(), "\tvoid directed_graph::add_edge(parent_node_index, child_node_index)" << "\n\tYou have specified an invalid index" << "\n\tparent_node_index: " << parent_node_index << "\n\tchild_node_index: " << child_node_index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); DLIB_CASSERT( self.has_edge(parent_node_index, child_node_index) == false, "\tvoid directed_graph::add_edge(parent_node_index, child_node_index)" << "\n\tYou can't add an edge if it already exists in the graph" << "\n\tparent_node_index: " << parent_node_index << "\n\tchild_node_index: " << child_node_index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); } static void check_remove_edge ( unsigned long parent_node_index, unsigned long child_node_index, const directed_graph& self ) { DLIB_CASSERT(parent_node_index < self.number_of_nodes() && child_node_index < self.number_of_nodes(), "\tvoid directed_graph::remove_edge(parent_node_index, child_node_index)" << "\n\tYou have specified an invalid index" << "\n\tparent_node_index: " << parent_node_index << 
"\n\tchild_node_index: " << child_node_index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); DLIB_CASSERT( self.has_edge(parent_node_index, child_node_index) == true, "\tvoid directed_graph::remove_edge(parent_node_index, child_node_index)" << "\n\tYou can't remove an edge if it isn't in the graph" << "\n\tparent_node_index: " << parent_node_index << "\n\tchild_node_index: " << child_node_index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); } static void check_remove_node ( unsigned long index, const directed_graph& self ) { // make sure requires clause is not broken DLIB_CASSERT(index < self.number_of_nodes(), "\tvoid directed_graph::remove_node(index)" << "\n\tYou have specified an invalid index" << "\n\tindex: " << index << "\n\tnumber_of_nodes(): " << self.number_of_nodes() << "\n\tthis: " << &self ); } }; template struct directed_graph_checker_helper { static inline void check_parent ( unsigned long , const node_type&) { } static inline void check_child ( unsigned long , const node_type& ) { } static inline void check_parent_edge ( unsigned long , const node_type&) { } static inline void check_child_edge ( unsigned long , const node_type& ) { } static inline void check_node ( unsigned long , const directed_graph& ) { } static inline void check_has_edge ( unsigned long , unsigned long , const directed_graph& ) { } static inline void check_add_edge ( unsigned long , unsigned long , const directed_graph& ) { } static inline void check_remove_edge ( unsigned long , unsigned long , const directed_graph& ) { } static inline void check_remove_node ( unsigned long , const directed_graph& ) { } }; // ---------------------------------------------------------------------------------------- template < typename T, typename E = char, typename mem_manager = default_memory_manager, bool is_checked = true > class directed_graph_kernel_1 : noncopyable { /*! 
INITIAL VALUE - nodes.size() == 0 CONVENTION - nodes.size() == number_of_nodes() - for all valid i: - *nodes[i] == node(i) - nodes[i]->parents.size() == nodes[i]->number_of_parents(i) - nodes[i]->children.size() == nodes[i]->number_of_children(i) - nodes[i]->edge_parents.size() == nodes[i]->number_of_parents(i) - nodes[i]->edge_children.size() == nodes[i]->number_of_children(i) - nodes[i]->idx == i == nodes[i]->index() - for all valid p: - nodes[i]->parents[p] == pointer to the p'th parent node of i - *nodes[i]->parents[p] == nodes[i]->parent(p) - *nodes[i]->edge_parents[p] == nodes[i]->parent_edge(p) - for all valid c: - nodes[i]->children[c] == pointer to the c'th child node of i - *nodes[i]->children[c] == nodes[i]->child(c) - *nodes[i]->edge_children[c] == nodes[i]->child_edge(c) !*/ public: struct node_type; private: typedef directed_graph_checker_helper checker; public: typedef T type; typedef E edge_type; typedef mem_manager mem_manager_type; template struct rebind { typedef directed_graph_kernel_1 other; }; directed_graph_kernel_1( ) {} virtual ~directed_graph_kernel_1( ) {} void clear( ) { nodes.clear(); } void set_number_of_nodes ( unsigned long new_size ); unsigned long number_of_nodes ( ) const { return nodes.size(); } node_type& node ( unsigned long index ) { checker::check_node(index,*this); return *nodes[index]; } const node_type& node ( unsigned long index ) const { checker::check_node(index,*this); return *nodes[index]; } bool has_edge ( unsigned long parent_node_index, unsigned long child_node_index ) const; void add_edge ( unsigned long parent_node_index, unsigned long child_node_index ); void remove_edge ( unsigned long parent_node_index, unsigned long child_node_index ); unsigned long add_node ( ); void remove_node ( unsigned long index ); void swap ( directed_graph_kernel_1& item ) { nodes.swap(item.nodes); } private: public: struct node_type { T data; typedef directed_graph_kernel_1 graph_type; unsigned long index( ) const { return idx; } 
unsigned long number_of_parents ( ) const { return parents.size(); } unsigned long number_of_children ( ) const { return children.size(); } const node_type& parent ( unsigned long edge_index ) const { checker::check_parent(edge_index,*this); return *parents[edge_index]; } node_type& parent ( unsigned long edge_index ) { checker::check_parent(edge_index,*this); return *parents[edge_index]; } const node_type& child ( unsigned long edge_index ) const { checker::check_child(edge_index,*this); return *children[edge_index]; } node_type& child ( unsigned long edge_index ) { checker::check_child(edge_index,*this); return *children[edge_index]; } const E& parent_edge ( unsigned long edge_index ) const { checker::check_parent_edge(edge_index,*this); return *edge_parents[edge_index]; } E& parent_edge ( unsigned long edge_index ) { checker::check_parent_edge(edge_index,*this); return *edge_parents[edge_index]; } const E& child_edge ( unsigned long edge_index ) const { checker::check_child_edge(edge_index,*this); return *edge_children[edge_index]; } E& child_edge ( unsigned long edge_index ) { checker::check_child_edge(edge_index,*this); return *edge_children[edge_index]; } private: friend class directed_graph_kernel_1; typedef std_allocator alloc_type; typedef std_allocator,mem_manager> alloc_edge_type; std::vector parents; std::vector children; std::vector,alloc_edge_type> edge_parents; std::vector,alloc_edge_type> edge_children; unsigned long idx; }; private: typedef std_allocator,mem_manager> alloc_type; typedef std::vector, alloc_type> vector_type; vector_type nodes; }; // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > struct is_directed_graph > { static const bool value = true; }; // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > inline void 
swap ( directed_graph_kernel_1& a, directed_graph_kernel_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > void serialize ( const directed_graph_kernel_1& item, std::ostream& out ) { try { serialize(item.number_of_nodes(), out); // serialize each node for (unsigned long i = 0; i < item.number_of_nodes(); ++i) { serialize(item.node(i).data, out); // serialize all the child edges serialize(item.node(i).number_of_children(), out); for (unsigned long c = 0; c < item.node(i).number_of_children(); ++c) { serialize(item.node(i).child(c).index(), out); serialize(item.node(i).child_edge(c), out); } } } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type directed_graph_kernel_1"); } } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > void deserialize ( directed_graph_kernel_1& item, std::istream& in ) { try { unsigned long size; deserialize(size, in); item.clear(); item.set_number_of_nodes(size); // deserialize each node for (unsigned long i = 0; i < item.number_of_nodes(); ++i) { deserialize(item.node(i).data, in); unsigned long num_children; deserialize(num_children, in); // Add all the edges going to this nodes children nodes for (unsigned long c = 0; c < num_children; ++c) { unsigned long child_index; deserialize(child_index, in); item.add_edge(i, child_index); // find the edge we just added for (unsigned long j = 0; j < item.node(i).number_of_children(); ++j) { if (item.node(i).child(j).index() == child_index) { deserialize(item.node(i).child_edge(j), in); break; } } } } } catch (serialization_error& e) { throw serialization_error(e.info + "\n while deserializing object of type directed_graph_kernel_1"); } } // 
---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // member function definitions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > void directed_graph_kernel_1:: set_number_of_nodes ( unsigned long new_size ) { try { nodes.resize(new_size); for (unsigned long i = 0; i < nodes.size(); ++i) { nodes[i].reset(new node_type); nodes[i]->idx = i; } } catch (...) { clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > bool directed_graph_kernel_1:: has_edge ( unsigned long parent_node_index, unsigned long child_node_index ) const { checker::check_has_edge(parent_node_index, child_node_index, *this); node_type& n = *nodes[parent_node_index]; // search all the child nodes to see if there is a link to the right node for (unsigned long i = 0; i < n.children.size(); ++i) { if (n.children[i]->idx == child_node_index) return true; } return false; } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > void directed_graph_kernel_1:: add_edge ( unsigned long parent_node_index, unsigned long child_node_index ) { checker::check_add_edge(parent_node_index, child_node_index, *this); try { node_type& p = *nodes[parent_node_index]; node_type& c = *nodes[child_node_index]; p.children.push_back(&c); c.parents.push_back(&p); p.edge_children.push_back(std::shared_ptr(new E)); c.edge_parents.push_back(p.edge_children.back()); } catch (...) 
{ clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > void directed_graph_kernel_1:: remove_edge ( unsigned long parent_node_index, unsigned long child_node_index ) { checker::check_remove_edge(parent_node_index, child_node_index, *this); node_type& p = *nodes[parent_node_index]; node_type& c = *nodes[child_node_index]; // remove the record of the link from the parent node unsigned long pos = static_cast(find( p.children.begin(), p.children.end(), &c) - p.children.begin()); p.children.erase(p.children.begin()+pos); p.edge_children.erase(p.edge_children.begin()+pos); // remove the record of the link from the child node pos = static_cast(find( c.parents.begin(), c.parents.end(), &p) - c.parents.begin()); c.parents.erase(c.parents.begin() + pos); c.edge_parents.erase(c.edge_parents.begin() + pos); } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > unsigned long directed_graph_kernel_1:: add_node ( ) { try { std::shared_ptr n(new node_type); n->idx = nodes.size(); nodes.push_back(n); return n->idx; } catch (...) 
{ clear(); throw; } } // ---------------------------------------------------------------------------------------- template < typename T, typename E, typename mem_manager, bool is_checked > void directed_graph_kernel_1:: remove_node ( unsigned long index ) { checker::check_remove_node(index,*this); node_type& n = *nodes[index]; // remove all edges pointing to this node from its parents for (unsigned long i = 0; i < n.parents.size(); ++i) { // remove the edge from this specific parent unsigned long pos = static_cast(find(n.parents[i]->children.begin(), n.parents[i]->children.end(), &n) - n.parents[i]->children.begin()); n.parents[i]->children.erase(n.parents[i]->children.begin() + pos); n.parents[i]->edge_children.erase(n.parents[i]->edge_children.begin() + pos); } // remove all edges pointing to this node from its children for (unsigned long i = 0; i < n.children.size(); ++i) { // remove the edge from this specific child unsigned long pos = static_cast(find(n.children[i]->parents.begin(), n.children[i]->parents.end(), &n) - n.children[i]->parents.begin()); n.children[i]->parents.erase(n.children[i]->parents.begin() + pos); n.children[i]->edge_parents.erase(n.children[i]->edge_parents.begin() + pos); } // now remove this node by replacing it with the last node in the nodes vector nodes[index] = nodes[nodes.size()-1]; // update the index for the node we just moved nodes[index]->idx = index; // now remove the duplicated node at the end of the vector nodes.pop_back(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_DIRECTED_GRAPH_KERNEl_1_ ================================================ FILE: dlib/directed_graph/directed_graph_kernel_abstract.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#undef DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_ #ifdef DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_ #include "../serialize.h" #include "../algs.h" #include "../noncopyable.h" namespace dlib { template < typename T, typename E = char, typename mem_manager = default_memory_manager > class directed_graph : noncopyable { /*! REQUIREMENTS ON T T must be swappable by a global swap() and T must have a default constructor REQUIREMENTS ON E E must be swappable by a global swap() and E must have a default constructor REQUIREMENTS ON mem_manager must be an implementation of memory_manager/memory_manager_kernel_abstract.h or must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h mem_manager::type can be set to anything. POINTERS AND REFERENCES TO INTERNAL DATA The only time pointers or references to nodes or edges become invalid is when they reference nodes or edges that have been removed from a graph. INITIAL VALUE number_of_nodes() == 0 WHAT THIS OBJECT REPRESENTS This object represents a directed graph which is a set of nodes with directed edges connecting various nodes. In this object if there is a directed edge from a node A to a node B then I say that A is the parent of B and B is the child of A. Also note that unless specified otherwise, no member functions of this object throw exceptions. !*/ public: typedef T type; typedef E edge_type; typedef mem_manager mem_manager_type; template struct rebind { typedef directed_graph other; }; directed_graph( ); /*! ensures - #*this is properly initialized throws - std::bad_alloc or any exception thrown by T's constructor. !*/ virtual ~directed_graph( ); /*! ensures - all resources associated with *this has been released !*/ void clear( ); /*! ensures - #*this has its initial value throws - std::bad_alloc or any exception thrown by T's constructor. 
If this exception is thrown then *this is unusable until clear() is called and succeeds !*/ void set_number_of_nodes ( unsigned long new_size ); /*! ensures - #number_of_nodes() == new_size - for all i < new_size: - number_of_parents(i) == 0 - number_of_children(i) == 0 throws - std::bad_alloc or any exception thrown by T's constructor. If this exception is thrown then this object reverts back to its initial state. !*/ unsigned long number_of_nodes ( ) const; /*! ensures - returns the number of nodes in this graph !*/ struct node_type { T data; typedef directed_graph graph_type; unsigned long index( ) const; /*! ensures - let G be the graph that contains the node *this - returns a number N such that G.node(N) == *this (i.e. returns the index of this node in the graph) !*/ unsigned long number_of_parents ( ) const; /*! ensures - returns the number of parents of this node !*/ unsigned long number_of_children ( ) const; /*! ensures - returns the number of children of this node !*/ const node_type& parent ( unsigned long edge_index ) const; /*! requires - edge_index < number_of_parents() ensures - returns a const reference to the edge_index'th parent of *this !*/ node_type& parent ( unsigned long edge_index ); /*! requires - edge_index < number_of_parents() ensures - returns a non-const reference to the edge_index'th parent of *this !*/ const node_type& child ( unsigned long edge_index ) const; /*! requires - edge_index < number_of_children() ensures - returns a const reference to the edge_index'th child of *this !*/ node_type& child ( unsigned long edge_index ); /*! requires - edge_index < number_of_children() ensures - returns a non-const reference to the edge_index'th child of *this !*/ const E& parent_edge ( unsigned long edge_index ) const; /*! requires - edge_index < number_of_parents() ensures - returns a const reference to the edge_index'th edge data for the edge connecting to node this->parent(edge_index) !*/ E& parent_edge ( unsigned long edge_index ); /*! 
requires - edge_index < number_of_parents() ensures - returns a non-const reference to the edge_index'th edge data for the edge connecting to node this->parent(edge_index) !*/ const E& child_edge ( unsigned long edge_index ) const; /*! requires - edge_index < number_of_children() ensures - returns a const reference to the edge_index'th edge data for the edge connecting to node this->child(edge_index) !*/ E& child_edge ( unsigned long edge_index ); /*! requires - edge_index < number_of_children() ensures - returns a non-const reference to the edge_index'th edge data for the edge connecting to node this->child(edge_index) !*/ }; node_type& node ( unsigned long index ); /*! requires - index < number_of_nodes() ensures - returns a non-const reference to the node with the given index !*/ const node_type& node ( unsigned long index ) const; /*! requires - index < number_of_nodes() ensures - returns a const reference to the node with the given index !*/ bool has_edge ( unsigned long parent_node_index, unsigned long child_node_index ) const; /*! requires - parent_node_index < number_of_nodes() - child_node_index < number_of_nodes() ensures - if (there is an edge leading from node(parent_node_index) to node(child_node_index)) then - returns true - else - returns false !*/ void add_edge ( unsigned long parent_node_index, unsigned long child_node_index ); /*! requires - parent_node_index < number_of_nodes() - child_node_index < number_of_nodes() - has_edge(parent_node_index, child_node_index) == false ensures - #has_edge(parent_node_index, child_node_index) == true throws - std::bad_alloc If this exception is thrown then this object reverts back to its initial state. !*/ void remove_edge ( unsigned long parent_node_index, unsigned long child_node_index ); /*! 
requires - parent_node_index < number_of_nodes() - child_node_index < number_of_nodes() - has_edge(parent_node_index, child_node_index) == true ensures - #has_edge(parent_node_index, child_node_index) == false throws - std::bad_alloc If this exception is thrown then this object reverts back to its initial state. !*/ unsigned long add_node ( ); /*! ensures - does not change the index number of existing nodes - adds a node with index N == number_of_nodes() such that: - #node(N).number_of_parents() == 0 - #node(N).number_of_children() == 0 - #number_of_nodes() == number_of_nodes() + 1 - returns N throws - std::bad_alloc or any exception thrown by T's constructor. If this exception is thrown then this object reverts back to its initial state. !*/ void remove_node ( unsigned long index ); /*! requires - index < number_of_nodes() ensures - removes the node with the given index from the graph. - removes all edges linking the removed node to the rest of the graph. - the remaining node indexes are remapped so that they remain contiguous. (This means that for all valid N, node(N) doesn't necessarily reference the same node as #node(N)) - #number_of_nodes() == number_of_nodes() - 1 throws - std::bad_alloc or any exception thrown by T's constructor. If this exception is thrown then this object reverts back to its initial state. !*/ void swap ( directed_graph& item ); /*! ensures - swaps *this and item !*/ }; template < typename T, typename mem_manager > inline void swap ( directed_graph& a, directed_graph& b ) { a.swap(b); } /*! provides a global swap function !*/ template < typename T, typename mem_manager > void serialize ( const directed_graph& item, std::ostream& out ); /*! provides deserialization support !*/ template < typename T, typename mem_manager > void deserialize ( directed_graph& item, std::istream& in ); /*! 
provides deserialization support !*/ } #endif // DLIB_DIRECTED_GRAPH_KERNEl_ABSTRACT_ ================================================ FILE: dlib/directed_graph.h ================================================ // Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DIRECTED_GRAPh_ #define DLIB_DIRECTED_GRAPh_ #include "directed_graph/directed_graph_kernel_1.h" #include "algs.h" namespace dlib { template < typename T, typename E = char, typename mem_manager = default_memory_manager > class directed_graph { directed_graph() {} public: //----------- kernels --------------- // kernel_1a typedef directed_graph_kernel_1 kernel_1a; typedef directed_graph_kernel_1 kernel_1a_c; }; } #endif // DLIB_DIRECTED_GRAPh_ ================================================ FILE: dlib/disjoint_subsets/disjoint_subsets.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DISJOINT_SUBsETS_Hh_ #define DLIB_DISJOINT_SUBsETS_Hh_ #include "disjoint_subsets_abstract.h" #include #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- class disjoint_subsets { public: void clear ( ) noexcept { items.clear(); } void set_size ( unsigned long new_size ) { items.resize(new_size); for (unsigned long i = 0; i < items.size(); ++i) { items[i].parent = i; items[i].rank = 0; } } size_t size ( ) const noexcept { return items.size(); } unsigned long find_set ( unsigned long item ) const { // make sure requires clause is not broken DLIB_ASSERT(item < size(), "\t unsigned long disjoint_subsets::find_set()" << "\n\t item must be less than size()" << "\n\t item: " << item << "\n\t size(): " << size() << "\n\t this: " << this ); if (items[item].parent == item) { return item; } else { // find root of item unsigned long x = item; do { x = items[x].parent; } while (items[x].parent != x); // do path compression const unsigned long root = x; x = item; while (items[x].parent != x) { const unsigned long prev = x; x = items[x].parent; items[prev].parent = root; } return root; } } unsigned long merge_sets ( unsigned long a, unsigned long b ) { // make sure requires clause is not broken DLIB_ASSERT(a != b && a < size() && b < size() && find_set(a) == a && find_set(b) == b, "\t unsigned long disjoint_subsets::merge_sets(a,b)" << "\n\t invalid arguments were given to this function" << "\n\t a: " << a << "\n\t b: " << b << "\n\t size(): " << size() << "\n\t find_set(a): " << find_set(a) << "\n\t find_set(b): " << find_set(b) << "\n\t this: " << this ); if (items[a].rank > items[b].rank) { items[b].parent = a; return a; } else { items[a].parent = b; if (items[a].rank == items[b].rank) { items[b].rank = items[b].rank + 1; } return b; } } private: /* See the book Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein for a discussion of how this algorithm works. 
*/ struct data { unsigned long rank; unsigned long parent; }; mutable std::vector items; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DISJOINT_SUBsETS_Hh_ ================================================ FILE: dlib/disjoint_subsets/disjoint_subsets_abstract.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ #ifdef DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ #include #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- class disjoint_subsets { /*! INITIAL VALUE - size() == 0 WHAT THIS OBJECT REPRESENTS This object represents a set of integers which is partitioned into a number of disjoint subsets. It supports the two fundamental operations of finding which subset a particular integer belongs to as well as merging subsets. !*/ public: void clear ( ) noexcept; /*! ensures - #size() == 0 - returns this object to its initial value !*/ void set_size ( unsigned long new_size ); /*! ensures - #size() == new_size - for all valid i: - #find_set(i) == i (i.e. this object contains new_size subsets, each containing exactly one element) !*/ size_t size ( ) const noexcept; /*! ensures - returns the total number of integer elements represented by this object. !*/ unsigned long find_set ( unsigned long item ) const; /*! requires - item < size() ensures - Each disjoint subset can be represented by any of its elements (since the sets are all disjoint). In particular, for each subset we define a special "representative element" which is used to represent it. Therefore, this function returns the representative element for the set which contains item. 
- find_set(find_set(item)) == find_set(item) - Note that if A and B are both elements of the same subset then we always have find_set(A) == find_set(B). !*/ unsigned long merge_sets ( unsigned long a, unsigned long b ); /*! requires - a != b - a < size() - b < size() - find_set(a) == a (i.e. a is the representative element of some set) - find_set(b) == b (i.e. b is the representative element of some set) ensures - #find_set(a) == #find_set(b) (i.e. merges the set's containing a and b) - returns #find_set(a) !*/ }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ ================================================ FILE: dlib/disjoint_subsets/disjoint_subsets_sized.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DISJOINT_SUBsETS_SIZED_Hh_ #define DLIB_DISJOINT_SUBsETS_SIZED_Hh_ #include "disjoint_subsets_sized_abstract.h" #include "disjoint_subsets.h" #include #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- class disjoint_subsets_sized { public: void clear ( ) noexcept { disjoint_subsets_.clear(); sets_size.clear(); number_of_sets = 0; } void set_size ( unsigned long new_size ) { disjoint_subsets_.set_size(new_size); sets_size.assign(new_size, 1); number_of_sets = new_size; } size_t size ( ) const noexcept { return disjoint_subsets_.size(); } unsigned long find_set ( unsigned long item ) const { // make sure requires clause is not broken DLIB_ASSERT(item < size(), "\t unsigned long disjoint_subsets::find_set()" << "\n\t item must be less than size()" << "\n\t item: " << item << "\n\t size(): " << size() << "\n\t this: " << this ); return disjoint_subsets_.find_set(item); } unsigned long merge_sets ( unsigned long a, unsigned long b ) { // make sure requires 
clause is not broken DLIB_ASSERT(a != b && a < size() && b < size() && find_set(a) == a && find_set(b) == b, "\t unsigned long disjoint_subsets::merge_sets(a,b)" << "\n\t invalid arguments were given to this function" << "\n\t a: " << a << "\n\t b: " << b << "\n\t size(): " << size() << "\n\t find_set(a): " << find_set(a) << "\n\t find_set(b): " << find_set(b) << "\n\t this: " << this ); disjoint_subsets_.merge_sets(a, b); if (find_set(a) == a) sets_size[a] += sets_size[b]; else sets_size[b] += sets_size[a]; --number_of_sets; return find_set(a); } unsigned long get_number_of_sets ( ) const noexcept { return number_of_sets; } unsigned long get_size_of_set( unsigned long item ) const { // make sure requires clause is not broken DLIB_ASSERT(item < size() && find_set(item) == item, "\t unsigned long disjoint_subsets::get_size_of_set()" << "\n\t invalid arguments were given to this function" << "\n\t item: " << item << "\n\t size(): " << size() << "\n\t find_set(item): " << find_set(item) << "\n\t this: " << this ); return sets_size[item]; } private: /* See the book Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein for a discussion of how this algorithm works. */ mutable std::vector sets_size; unsigned long number_of_sets{0}; disjoint_subsets disjoint_subsets_; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DISJOINT_SUBsETS_SIZED_Hh_ ================================================ FILE: dlib/disjoint_subsets/disjoint_subsets_sized_abstract.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_DISJOINT_SUBsETS_SIZED_ABSTRACT_Hh_ #ifdef DLIB_DISJOINT_SUBsETS_SIZED_ABSTRACT_Hh_ #include #include "../algs.h" namespace dlib { // ---------------------------------------------------------------------------------------- class disjoint_subsets_sized { /*! 
INITIAL VALUE - size() == 0 - get_number_of_sets() == 0 WHAT THIS OBJECT REPRESENTS This object represents a set of integers which is partitioned into a number of disjoint subsets. It supports the two fundamental operations of finding which subset a particular integer belongs to as well as merging subsets. It also allows you to find out how big each subset is. It is therefore essentially the same thing as dlib::disjoint_subsets, except it also keeps track of the size of each subset. !*/ public: void clear ( ) noexcept; /*! ensures - #size() == 0 - #get_number_of_sets() == 0 - returns this object to its initial value !*/ void set_size ( unsigned long new_size ); /*! ensures - #size() == new_size - #get_number_of_sets() == new_size - for all valid i: - #find_set(i) == i (i.e. this object contains new_size subsets, each containing exactly one element) - #get_size_of_set(i) == 1 !*/ size_t size ( ) const noexcept; /*! ensures - returns the total number of integer elements represented by this object. !*/ unsigned long find_set ( unsigned long item ) const; /*! requires - item < size() ensures - Each disjoint subset can be represented by any of its elements (since the sets are all disjoint). In particular, for each subset we define a special "representative element" which is used to represent it. Therefore, this function returns the representative element for the set which contains item. - find_set(find_set(item)) == find_set(item) - Note that if A and B are both elements of the same subset then we always have find_set(A) == find_set(B). !*/ unsigned long merge_sets ( unsigned long a, unsigned long b ); /*! requires - a != b - a < size() - b < size() - find_set(a) == a (i.e. a is the representative element of some set) - find_set(b) == b (i.e. b is the representative element of some set) ensures - #find_set(a) == #find_set(b) (i.e. 
merges the set's containing a and b) - #get_size_of_set(#find_set(a)) == get_size_of_set(a) + get_size_of_set(b) - #get_number_of_sets() == get_number_of_sets() - 1 - returns #find_set(a) !*/ unsigned long get_number_of_sets ( ) const noexcept; /*! ensures - returns the current number of different subsets. !*/ unsigned long get_size_of_set( unsigned long item ) const; /*! requires - item < size() - find_set(item) == item (i.e. item is the representative element of some set) ensures - returns the number of elements which belongs to the set where item is the representative element. !*/ }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DISJOINT_SUBsETS_ABSTRACT_Hh_ ================================================ FILE: dlib/disjoint_subsets.h ================================================ // Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_DISJOINt_SUBSETS_ #define DLIB_DISJOINt_SUBSETS_ #include "disjoint_subsets/disjoint_subsets.h" #include "disjoint_subsets/disjoint_subsets_sized.h" #endif // DLIB_DISJOINt_SUBSETS_ ================================================ FILE: dlib/dlib_basic_cpp_build_tutorial.txt ================================================ #error "Don't write #include in your code." /* In C++, it is generally an error to #include .cpp files. This is because it can lead to what are called multiply defined symbol errors. Therefore, you should compile dlib/all/source.cpp into your application just like you would compile any other .cpp file. If you are using Visual Studio you add .cpp files to your application using the solution explorer window. Specifically, right click on Source Files, then select Add -> Existing Item and select the .cpp files you want to add. 
For general information on compiling dlib see http://dlib.net/compile.html
*/

================================================
FILE: dlib/dlib_include_path_tutorial.txt
================================================
#error "Don't put the dlib folder in your include path"
/*
    You are getting this error because you have added the dlib folder to your
    compiler's include search path.

    You should *NOT* add the dlib folder itself to your compiler's include path.
    Doing so will cause the build to fail because of name collisions (such as
    dlib/string.h and string.h from the standard library).  Instead you should
    add the folder that contains the dlib folder to your include search path
    and then use include statements of the form #include <dlib/queue.h> or
    #include "dlib/queue.h".  This will ensure that everything builds correctly.

    XCode:
        The XCode IDE often puts all folders that it knows about into
        the compiler search path.  So if you are using XCode then either
        don't drag the whole dlib folder into the project or alternatively
        modify your XCode project settings to not auto-add all folders to
        the include path.  Instead just make sure that the dlib folder is
        itself inside a folder in your include path.
*/

================================================
FILE: dlib/dnn/core.h
================================================
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_CORE_H_
#define DLIB_DNn_CORE_H_

// NOTE(review): the targets of the next eight #include directives, and the template
// parameter lists / template arguments throughout this span (e.g. "template double
// get_learning_rate_multiplier(const T& obj)" uses T without declaring it), appear to
// have been stripped by text extraction.  Restore them from upstream before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include "core_abstract.h"
#include "../cuda/tensor.h"
#include "../cuda/tensor_tools.h"
#include "../statistics.h"
#include "../rand.h"
#include "../algs.h"
#include "../metaprogramming.h"
#include "../utility.h"
#include "../constexpr_if.h"

#ifdef _MSC_VER
// Tell Visual Studio not to recursively inline functions very much because otherwise it
// takes hours to compile the DNN code sometimes.  It's crazy.  Hopefully we can remove
// this some day when the visual studio compiler is more efficient.
#pragma inline_depth(2)
#endif

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        // Detection aliases: each one names the type of an optional member-function call,
        // and is used with is_detected below to test whether a layer type provides that
        // member.  The setter aliases probe the call with a double argument.
        template using has_get_learning_rate_multiplier = decltype(std::declval().get_learning_rate_multiplier());
        template using has_set_learning_rate_multiplier = decltype(std::declval().set_learning_rate_multiplier(double{}));
        template using has_get_bias_learning_rate_multiplier = decltype(std::declval().get_bias_learning_rate_multiplier());
        template using has_set_bias_learning_rate_multiplier = decltype(std::declval().set_bias_learning_rate_multiplier(double{}));
        template using has_get_weight_decay_multiplier = decltype(std::declval().get_weight_decay_multiplier());
        template using has_set_weight_decay_multiplier = decltype(std::declval().set_weight_decay_multiplier(double{}));
        template using has_get_bias_weight_decay_multiplier = decltype(std::declval().get_bias_weight_decay_multiplier());
        template using has_set_bias_weight_decay_multiplier = decltype(std::declval().set_bias_weight_decay_multiplier(double{}));
        template using has_disable_bias = decltype(std::declval().disable_bias());
        template using has_clean = decltype(std::declval().clean());
    }

// ----------------------------------------------------------------------------------------

    // Returns obj.get_learning_rate_multiplier() when T provides that member; the
    // fallback branch returns 1.0.  (Assumes switch_ picks the first lambda when the
    // detected bool is true, as the true_t parameter suggests.)
    template
    double get_learning_rate_multiplier(const T& obj)
    {
        return switch_(bools(is_detected{}),
            [&](true_t, auto _) { return _(obj).get_learning_rate_multiplier(); },
            [](auto...) { return 1.0; }
        );
    }

    // Asserts the multiplier is non-negative, then forwards it to
    // obj.set_learning_rate_multiplier() when T provides that member; otherwise no-op.
    template
    void set_learning_rate_multiplier(
        T& obj,
        double learning_rate_multiplier
    )
    {
        DLIB_CASSERT(learning_rate_multiplier >= 0);
        switch_(bools(is_detected{}),
            [&](true_t, auto _) { _(obj).set_learning_rate_multiplier(learning_rate_multiplier); },
            [](auto...) {/*no-op*/}
        );
    }

// ----------------------------------------------------------------------------------------

    // Returns obj.get_bias_learning_rate_multiplier() when available, otherwise 1.0.
    template
    double get_bias_learning_rate_multiplier(const T& obj)
    {
        return switch_(bools(is_detected{}),
            [&](true_t, auto _) { return _(obj).get_bias_learning_rate_multiplier(); },
            [](auto...) { return 1.0; }
        );
    }

    // Asserts the multiplier is non-negative, then forwards it to
    // obj.set_bias_learning_rate_multiplier() when available; otherwise no-op.
    template
    void set_bias_learning_rate_multiplier(
        T& obj,
        double bias_learning_rate_multiplier
    )
    {
        DLIB_CASSERT(bias_learning_rate_multiplier >= 0);
        switch_(bools(is_detected{}),
            [&](true_t, auto _) { _(obj).set_bias_learning_rate_multiplier(bias_learning_rate_multiplier); },
            [](auto...) {/*no-op*/}
        );
    }

// ----------------------------------------------------------------------------------------

    // Returns obj.get_weight_decay_multiplier() when available, otherwise 1.0.
    template
    double get_weight_decay_multiplier(const T& obj)
    {
        return switch_(bools(is_detected{}),
            [&](true_t, auto _) { return _(obj).get_weight_decay_multiplier(); },
            [](auto...) { return 1.0; }
        );
    }

    // Asserts the multiplier is non-negative, then forwards it to
    // obj.set_weight_decay_multiplier() when available; otherwise no-op.
    template
    void set_weight_decay_multiplier(
        T& obj,
        double weight_decay_multiplier
    )
    {
        DLIB_CASSERT(weight_decay_multiplier >= 0);
        switch_(bools(is_detected{}),
            [&](true_t, auto _) { _(obj).set_weight_decay_multiplier(weight_decay_multiplier); },
            [](auto...) {/*no-op*/}
        );
    }

// ----------------------------------------------------------------------------------------

    // Returns obj.get_bias_weight_decay_multiplier() when available, otherwise 1.0.
    template
    double get_bias_weight_decay_multiplier(const T& obj)
    {
        return switch_(bools(is_detected{}),
            [&](true_t, auto _) { return _(obj).get_bias_weight_decay_multiplier(); },
            [](auto...) { return 1.0; }
        );
    }

    // Asserts the multiplier is non-negative, then forwards it to
    // obj.set_bias_weight_decay_multiplier() when available; otherwise no-op.
    template
    void set_bias_weight_decay_multiplier(
        T& obj,
        double bias_weight_decay_multiplier
    )
    {
        DLIB_CASSERT(bias_weight_decay_multiplier >= 0);
        switch_(bools(is_detected{}),
            [&](true_t, auto _) { _(obj).set_bias_weight_decay_multiplier(bias_weight_decay_multiplier); },
            [](auto...) {/*no-op*/}
        );
    }

// ----------------------------------------------------------------------------------------

    // Calls obj.disable_bias() when T provides it; otherwise does nothing.
    template
    void disable_bias(
        T& obj
    )
    {
        switch_(bools(is_detected{}),
            [&](true_t, auto _) { _(obj).disable_bias(); },
            [](auto...) { /*no-op*/ }
        );
    }

// ----------------------------------------------------------------------------------------

    template
    void call_clean_method_if_exists(T& obj)
    /*!
        ensures
            - calls obj.clean() if obj has a .clean() method.
    !*/
    {
        switch_(bools(is_detected{}),
            [&](true_t, auto _) { _(obj).clean(); },
            [](auto...) { /*no-op*/ }
        );
    }

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        class repeat_input_layer
        {
            /*!
                None of the declarations in this object are really used.  The only reason it
                exists is to allow the repeat object to use a special input layer in its
                internal networks which will cause add_tag_layer objects that happen to be
                right at the input to not create copies of their input tensors.  So
                introducing the repeat_input_layer object allows us to optimize the
                implementation of add_tag_layer for a special case that arises when it's
                used in the context of the repeat layer.
!*/ public: typedef int input_type; template void to_tensor ( forward_iterator , forward_iterator , resizable_tensor& ) const { } friend void serialize(const repeat_input_layer&, std::ostream&){} friend void deserialize(repeat_input_layer&, std::istream&){} friend std::ostream& operator<<(std::ostream& out, const repeat_input_layer&) { return out; } }; inline std::string tensor_to_str ( const tensor& t, int& min_length ) { if (t.size() == 0) return ""; std::ostringstream sout; sout << "output size=(num:"<< t.num_samples() << ", "; sout << "k:" << t.k() << ","; while (sout.tellp() < 28) sout << " "; sout << "nr:" << t.nr() << ","; while (sout.tellp() < 28+8) sout << " "; sout << "nc:" << t.nc() << ")"; while (sout.tellp() < min_length) sout << " "; min_length = sout.tellp(); sout << "\t"; return sout.str(); } } // ---------------------------------------------------------------------------------------- // Tell us if T is one of the special layer types (i.e. add_layer, repeat, add_tag_layer, or // add_skip_layer). template struct is_nonloss_layer_type : std::false_type {}; // Tell us if T is an instance of add_loss_layer. 
template struct is_loss_layer_type : std::false_type {}; // Tell us if T is an instance of add_layer template struct is_add_layer : std::false_type {}; namespace impl { template auto tuple_subset( const Tuple& item, std::index_sequence ) { return std::make_tuple(std::get(item)...); } template auto basic_tuple_tail( const std::tuple& item ) { return tuple_subset(item, pop_front_t>{}); } template auto tuple_flatten(const T& t) { return std::make_tuple(t); } template auto tuple_flatten( const std::tuple& item, std::index_sequence ) { return std::tuple_cat(tuple_flatten(std::get(item))...); } template auto tuple_flatten(const std::tuple& item) { return tuple_flatten(item, std::index_sequence_for{}); } template struct tuple_head_helper { typedef T type; static const type& get(const T& item) { return item; } }; template struct tuple_head_helper> { typedef typename tuple_head_helper::type type; static const type& get(const std::tuple& item) { return tuple_head_helper::get(std::get<0>(item)); } }; template struct alwaysbool { typedef bool type; }; // one more structure for VS 2015 UP3 support workaround template struct alwaysbool2 { typedef bool type; }; resizable_tensor& rt(); // The significance of a layer's backward method requiring forward's outputs is // that such as layer can't have an in-place layer stacked on top of it because // in-place layers overwrite the output of the layer they sit on top of. 
template constexpr auto backward_requires_forward_output( layer_type& layer, SUBNET& sub ) -> typename alwaysbool::type { return true; } template constexpr auto backward_requires_forward_output( layer_type& layer, SUBNET& sub ) -> typename alwaysbool::type { return false; } template constexpr auto backward_requires_forward_output( layer_type& layer, SUBNET& sub ) -> typename alwaysbool::type { return true; } template constexpr auto backward_requires_forward_output( layer_type& layer, SUBNET& sub ) -> typename alwaysbool::type { return false; } template constexpr auto has_inplace_backward( layer_type& layer, SUBNET& sub ) -> typename alwaysbool2::type { return false; } template constexpr auto has_inplace_backward( layer_type& layer, SUBNET& sub ) -> typename alwaysbool2::type { return false; } template constexpr auto has_inplace_backward( layer_type& layer, SUBNET& sub ) -> typename alwaysbool2::type { return true; } template constexpr auto has_inplace_backward( layer_type& layer, SUBNET& sub ) -> typename alwaysbool2::type { return true; } template constexpr auto is_inplace_layer( layer_type& layer, const SUBNET& sub ) -> typename alwaysbool2::type { return false; } template constexpr auto is_inplace_layer( layer_type& layer, const SUBNET& sub ) -> typename alwaysbool::type { return true; } template auto call_layer_backward( layer_type& layer, const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad ) -> decltype(layer.backward(computed_output,gradient_input,sub,params_grad)) { layer.backward(computed_output,gradient_input,sub,params_grad); } template auto call_layer_backward( layer_type& layer, const tensor& , const tensor& gradient_input, SUBNET& sub, tensor& params_grad ) -> decltype(layer.backward(gradient_input,sub,params_grad)) { layer.backward(gradient_input,sub,params_grad); } template auto call_layer_backward( layer_type& layer, const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& 
params_grad ) -> decltype(layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad)) { layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad); } template auto call_layer_backward( layer_type& layer, const tensor& , const tensor& gradient_input, SUBNET& sub, tensor& params_grad ) -> decltype(layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad)) { layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad); } template auto call_layer_forward( layer_type& layer, const SUBNET& sub, tensor& /*data_output*/ ) -> decltype(layer.forward(sub,rt())) { // This overload of call_layer_forward() is here because this template // naturally gets instantiated but only on code paths that never get executed. // So rather than writing a bunch of hard to read template magic around call // sites we just have this overload that doesn't do anything (and an assert to // make sure that's the case). DLIB_CASSERT(false, "This should never happen"); } template auto call_layer_forward( layer_type& layer, const SUBNET& sub, resizable_tensor& data_output ) -> decltype(layer.forward(sub,data_output)) { layer.forward(sub,data_output); } template auto call_layer_forward( layer_type& layer, const SUBNET& sub, tensor& data_output ) -> decltype(layer.forward_inplace(sub.get_output(),data_output)) { layer.forward_inplace(sub.get_output(),data_output); } template auto call_layer_forward( layer_type& layer, const SUBNET& sub, resizable_tensor& data_output ) -> decltype(layer.forward_inplace(sub.get_output(),data_output)) { if (!have_same_dimensions(data_output, sub.get_output())) data_output.copy_size(sub.get_output()); layer.forward_inplace(sub.get_output(),static_cast(data_output)); } } // end namespace impl template auto tuple_head ( const std::tuple& item ) { return impl::tuple_head_helper>::get(item); } template auto tuple_tail( const std::tuple& item ) { return 
impl::basic_tuple_tail(impl::tuple_flatten(item)); } inline std::tuple<> tuple_tail( const std::tuple<>& item ) { return item; } // ---------------------------------------------------------------------------------------- template class sstack { public: typedef T value_type; sstack() = delete; sstack ( T* data_, size_t s ) : data(data_), mysize(s) {} const T& top() const { DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack"); return *data; } T& top() { DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack"); return *data; } size_t size() const { return mysize; } sstack pop(size_t num=1) { DLIB_CASSERT(num <= size(), "You can't pop more things from the stack than it has in it."); return sstack(data+num, mysize-num); } private: T* data; size_t mysize; }; template sstack make_sstack(std::vector& item) { return sstack(item.data(), item.size()); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- namespace dimpl { template class subnet_wrapper { /*! WHAT THIS OBJECT REPRESENTS This is a tool that makes an add_layer or add_loss_layer object expose only the part of its interface defined by the SUBNET type in layers_abstract.h. This way, when we pass subnetwork objects to the layer callbacks those callbacks won't be able to interact with the subnetworks in a way other than specified by the SUBNET interface spec. We also allow the top layer of a subnet_wrapper stack to call the private_get_output() and private_get_gradient_input() functions. This way, layers that have had their output/gradient overwritten by in-place layers can only be accessed from the in-place layers that sit directly on top of them since those in-place layers are the only layers that know how to interact with them properly. 
!*/ public: subnet_wrapper(const subnet_wrapper&) = delete; subnet_wrapper& operator=(const subnet_wrapper&) = delete; subnet_wrapper(T& l_, unsigned int sef) : l(l_),_sample_expansion_factor(sef) {} // Not much here because in this case T is one of the input layer types // that doesn't have anything in it. typedef T layer_details_type; typedef T input_layer_type; const layer_details_type& layer_details() const { return l; } const input_layer_type& input_layer() const { return l; } input_layer_type& input_layer() { return l; } unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } private: T& l; unsigned int _sample_expansion_factor; }; template class subnet_wrapper::value>::type> { public: subnet_wrapper(const subnet_wrapper&) = delete; subnet_wrapper& operator=(const subnet_wrapper&) = delete; typedef T wrapped_type; const static size_t num_computational_layers = T::num_computational_layers; const static size_t num_layers = T::num_layers; typedef typename T::layer_details_type layer_details_type; typedef typename T::input_layer_type input_layer_type; subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {} const tensor& get_output() const { return l.private_get_output(); } tensor& get_gradient_input() { return l.private_get_gradient_input(); } const layer_details_type& layer_details() const { return l.layer_details(); } const subnet_wrapper& subnet() const { return subnetwork; } subnet_wrapper& subnet() { return subnetwork; } unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); } const input_layer_type& input_layer() const { return l.input_layer(); } input_layer_type& input_layer() { return l.input_layer(); } private: T& l; subnet_wrapper subnetwork; }; template class subnet_wrapper::value>::type> { public: subnet_wrapper(const subnet_wrapper&) = delete; subnet_wrapper& operator=(const subnet_wrapper&) = delete; typedef T wrapped_type; const static size_t 
num_computational_layers = T::num_computational_layers; const static size_t num_layers = T::num_layers; typedef typename T::layer_details_type layer_details_type; typedef typename T::input_layer_type input_layer_type; subnet_wrapper(T& l_, unsigned int = 0) : l(l_),subnetwork(l.subnet(), l.sample_expansion_factor()) {} const tensor& get_output() const { return l.get_output(); } tensor& get_gradient_input() { return l.get_gradient_input(); } const layer_details_type& layer_details() const { return l.layer_details(); } const subnet_wrapper& subnet() const { return subnetwork; } subnet_wrapper& subnet() { return subnetwork; } unsigned int sample_expansion_factor() const { return l.sample_expansion_factor(); } const input_layer_type& input_layer() const { return l.input_layer(); } input_layer_type& input_layer() { return l.input_layer(); } private: T& l; subnet_wrapper subnetwork; }; } // ---------------------------------------------------------------------------------------- enum class zero_gradients : uint8_t { no = 0, yes = 1 }; // ---------------------------------------------------------------------------------------- template class add_layer; template void serialize(const add_layer& item, std::ostream& out); template void deserialize(add_layer& item, std::istream& in); template struct is_nonloss_layer_type> : std::true_type {}; template class add_layer::value>::type> { public: typedef LAYER_DETAILS layer_details_type; typedef SUBNET subnet_type; typedef typename subnet_type::input_layer_type input_layer_type; typedef typename subnet_type::input_type input_type; const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_computational_layers = subnet_type::num_computational_layers + 1; add_layer( ): subnetwork(new subnet_type()), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false) { if (this_layer_operates_inplace()) subnetwork->disable_output_and_gradient_getters(); } 
add_layer(const add_layer& item) { details = item.details; subnetwork.reset(new subnet_type(*item.subnetwork)); this_layer_setup_called = item.this_layer_setup_called; gradient_input_is_stale = item.gradient_input_is_stale; get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled; x_grad = item.x_grad; cached_output = item.cached_output; params_grad = item.params_grad; temp_tensor = item.temp_tensor; } add_layer& operator=(const add_layer& item) { add_layer(item).swap(*this); return *this;} add_layer(add_layer&& item) : add_layer() { swap(item); } add_layer& operator=(add_layer&& item) { swap(item); return *this; } template friend class add_layer; template friend class dimpl::subnet_wrapper; template friend class add_tag_layer; template class T, typename U> friend class add_skip_layer; template class L, typename S> friend class repeat; // Allow copying networks from one to another as long as their corresponding // layers can be constructed from each other. template add_layer( const add_layer& item ) : details(item.layer_details()), subnetwork(new subnet_type(item.subnet())), this_layer_setup_called(item.this_layer_setup_called), gradient_input_is_stale(item.gradient_input_is_stale), get_output_and_gradient_input_disabled(item.get_output_and_gradient_input_disabled), x_grad(item.x_grad), cached_output(item.cached_output) { if (this_layer_operates_inplace()) subnetwork->disable_output_and_gradient_getters(); } template add_layer( const LAYER_DETAILS& layer_det, T&& ...args ) : details(layer_det), subnetwork(new subnet_type(std::forward(args)...)), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false) { if (this_layer_operates_inplace()) subnetwork->disable_output_and_gradient_getters(); } template struct disable_forwarding_constr { const static bool value = std::is_constructible::value; }; template struct disable_forwarding_constr,U...> { const static bool value = 
disable_forwarding_constr::type...>::value; }; template struct disable_forwarding_constr,U...> { const static bool value = disable_forwarding_constr::type>::value; }; template struct disable_forwarding_constr,U...> { const static bool value = true; }; template struct disable_forwarding_constr> { const static bool value = true; }; template < typename ...T, typename = typename std::enable_if::type...>::value>::type > add_layer( T&& ...args ) : subnetwork(new subnet_type(std::forward(args)...)), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false) { if (this_layer_operates_inplace()) subnetwork->disable_output_and_gradient_getters(); } template add_layer( LAYER_DETAILS&& layer_det, T&& ...args ) : details(std::move(layer_det)), subnetwork(new subnet_type(std::forward(args)...)), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false) { if (this_layer_operates_inplace()) subnetwork->disable_output_and_gradient_getters(); } template add_layer( const std::tuple& layer_det, T&& ...args ) : details(tuple_head(layer_det)), subnetwork(new subnet_type(tuple_tail(layer_det),std::forward(args)...)), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false) { if (this_layer_operates_inplace()) subnetwork->disable_output_and_gradient_getters(); } template add_layer( std::tuple<>, const std::tuple& layer_det, T&& ...args ) : add_layer(layer_det,args...) { } add_layer ( std::tuple<> ) : add_layer() {} template add_layer( std::tuple<>, LAYER_DETAILS&& layer_det, T&& ...args ) : add_layer(layer_det, args...) 
{ } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { subnetwork->to_tensor(ibegin,iend,data); } template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ) { to_tensor(ibegin,iend,temp_tensor); return forward(temp_tensor); } const tensor& operator() (const input_type& x) { return (*this)(&x, &x+1); } const tensor& forward(const tensor& x) { subnetwork->forward(x); const dimpl::subnet_wrapper wsub(*subnetwork); if (!this_layer_setup_called) { details.setup(wsub); this_layer_setup_called = true; } if (this_layer_operates_inplace()) impl::call_layer_forward(details, wsub, private_get_output()); else impl::call_layer_forward(details, wsub, cached_output); gradient_input_is_stale = true; return private_get_output(); } private: tensor& private_get_output() const { if (const_cast(*this).this_layer_operates_inplace()) return subnetwork->private_get_output(); else return const_cast(cached_output); } tensor& private_get_gradient_input() { if (this_layer_operates_inplace()) { return subnetwork->private_get_gradient_input(); } else { if (gradient_input_is_stale) { gradient_input_is_stale = false; x_grad.copy_size(private_get_output()); x_grad = 0; } return x_grad; } } void disable_output_and_gradient_getters ( ) { get_output_and_gradient_input_disabled = true; } public: const tensor& get_output() const { if (get_output_and_gradient_input_disabled) throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it."); return private_get_output(); } tensor& get_gradient_input() { if (get_output_and_gradient_input_disabled) throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it."); return private_get_gradient_input(); } const tensor& get_final_data_gradient( ) const { return subnetwork->get_final_data_gradient(); } void back_propagate_error( const tensor& x, 
zero_gradients zero_grads = zero_gradients::yes ) { back_propagate_error(x, private_get_gradient_input(), zero_grads); } void back_propagate_error( const tensor& x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ) { dimpl::subnet_wrapper wsub(*subnetwork); params_grad.copy_size(details.get_layer_params()); impl::call_layer_backward(details, private_get_output(), gradient_input, wsub, static_cast(params_grad)); subnetwork->back_propagate_error(x, zero_grads); // zero out get_gradient_input() gradient_input_is_stale = zero_grads == zero_gradients::yes; } template void update_parameters(sstack solvers, double learning_rate) { DLIB_CASSERT(solvers.size()>=num_computational_layers); // Don't try to adjust the parameters if this layer doesn't have any or the // learning rate is disabled for this layer. if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0) { const tensor& step = solvers.top()(learning_rate, details, static_cast(params_grad)); tt::add(details.get_layer_params(), details.get_layer_params(), step); } subnetwork->update_parameters(solvers.pop(), learning_rate); } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const tensor& get_parameter_gradient( ) const { return params_grad; } tensor& get_parameter_gradient ( ) { return params_grad; } const subnet_type& subnet() const { return *subnetwork; } subnet_type& subnet() { return *subnetwork; } const input_layer_type& input_layer() const { return subnet().input_layer(); } input_layer_type& input_layer() { return subnet().input_layer(); } const layer_details_type& layer_details() const { return details; } layer_details_type& layer_details() { return details; } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } void set_gradient_inputs_to_zero() { gradient_input_is_stale = true; subnetwork->set_gradient_inputs_to_zero(); } void clean() { 
x_grad.clear(); cached_output.clear(); params_grad.clear(); temp_tensor.clear(); gradient_input_is_stale = true; subnetwork->clean(); call_clean_method_if_exists(details); } friend void serialize(const add_layer& item, std::ostream& out) { int version = 2; serialize(version, out); serialize(*item.subnetwork, out); serialize(item.details, out); serialize(item.this_layer_setup_called, out); serialize(item.gradient_input_is_stale, out); serialize(item.get_output_and_gradient_input_disabled, out); serialize(item.x_grad, out); serialize(item.cached_output, out); serialize(item.params_grad, out); } friend void deserialize(add_layer& item, std::istream& in) { int version = 0; deserialize(version, in); if (!(1 <= version && version <= 2)) throw serialization_error("Unexpected version found while deserializing dlib::add_layer."); deserialize(*item.subnetwork, in); deserialize(item.details, in); deserialize(item.this_layer_setup_called, in); deserialize(item.gradient_input_is_stale, in); deserialize(item.get_output_and_gradient_input_disabled, in); deserialize(item.x_grad, in); deserialize(item.cached_output, in); if (version == 2) deserialize(item.params_grad, in); } friend std::ostream& operator<< (std::ostream& out, const add_layer& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << layer_details() << "\n"; subnet().print(out, idx+1, min_length); } private: bool this_layer_operates_inplace( ) { // This layer can run in-place if it's an in-place capable layer and also if // the layer it's on top of doesn't need its own output tensor (since in-place // layers overwrite that tensor) return impl::is_inplace_layer(details, *subnetwork) && !subnetwork->this_layer_requires_forward_output(); } bool this_layer_requires_forward_output( ) { return impl::backward_requires_forward_output(details, 
*subnetwork); } void swap(add_layer& item) { std::swap(subnetwork,item.subnetwork); std::swap(details, item.details); std::swap(this_layer_setup_called, item.this_layer_setup_called); std::swap(gradient_input_is_stale, item.gradient_input_is_stale); std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled); std::swap(x_grad, item.x_grad); std::swap(cached_output, item.cached_output); std::swap(params_grad, item.params_grad); } LAYER_DETAILS details; std::unique_ptr subnetwork; bool this_layer_setup_called; bool gradient_input_is_stale; bool get_output_and_gradient_input_disabled; // Note that if this_layer_operates_inplace()==true then x_grad and cached_output // are not used at all. Instead, this layer uses these variables from the lower // layer. resizable_tensor x_grad; resizable_tensor cached_output; resizable_tensor params_grad; // temp_tensor doesn't logically contribute to the state of this object. // It is here only to prevent it from being reallocated over and over. resizable_tensor temp_tensor; }; template struct is_add_layer> : std::true_type {}; template struct is_add_layer> : std::true_type {}; template struct is_add_layer&> : std::true_type {}; template struct is_add_layer&> : std::true_type {}; // ---------------------------------------------------------------------------------------- // This version of add_layer handles the special case where the subnetwork being given is // just an input layer object. 
template class add_layer { public: typedef LAYER_DETAILS layer_details_type; typedef INPUT_LAYER subnet_type; typedef INPUT_LAYER input_layer_type; typedef typename INPUT_LAYER::input_type input_type; const static size_t num_layers = 2; const static size_t num_computational_layers = 1; add_layer( ): this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false), _sample_expansion_factor(0) {} add_layer(const add_layer&) = default; add_layer(add_layer&& item) : add_layer() { swap(item); } add_layer& operator=(const add_layer&) = default; add_layer& operator=(add_layer&& item) { swap(item); return *this; } template friend class add_layer; template friend class dimpl::subnet_wrapper; template friend class add_tag_layer; template class T, typename U> friend class add_skip_layer; template class L, typename S> friend class repeat; // Allow copying networks from one to another as long as their corresponding // layers can be constructed from each other. 
template add_layer( const add_layer& item ): input_layer_(item.subnet()), details(item.layer_details()), this_layer_setup_called(item.this_layer_setup_called), gradient_input_is_stale(item.gradient_input_is_stale), get_output_and_gradient_input_disabled(false), _sample_expansion_factor(item._sample_expansion_factor), x_grad(item.x_grad), cached_output(item.cached_output), grad_final(item.grad_final) { } add_layer( const LAYER_DETAILS& layer_det ) : details(layer_det), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false), _sample_expansion_factor(0) {} add_layer( const INPUT_LAYER& il ) : input_layer_(il), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false), _sample_expansion_factor(0) {} add_layer( LAYER_DETAILS&& layer_det ) : details(std::move(layer_det)), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false), _sample_expansion_factor(0) {} add_layer( LAYER_DETAILS layer_det, INPUT_LAYER il ) : details(std::move(layer_det)), input_layer_(std::move(il)), this_layer_setup_called(false), gradient_input_is_stale(true), get_output_and_gradient_input_disabled(false), _sample_expansion_factor(0) {} add_layer( std::tuple<>, const LAYER_DETAILS& layer_det ) : add_layer(layer_det) {} add_layer( std::tuple<>, LAYER_DETAILS&& layer_det ) : add_layer(layer_det) {} add_layer( std::tuple<>, LAYER_DETAILS layer_det, INPUT_LAYER il ) : add_layer(layer_det,il) {} add_layer( const std::tuple& layer_det ) : add_layer(tuple_head(layer_det)) {} add_layer( const std::tuple& layer_det, INPUT_LAYER il ) : add_layer(tuple_head(layer_det),il) {} template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { input_layer_.to_tensor(ibegin, iend, data); // make sure the input layer's to_tensor() function is implemented properly. 
DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend), "The input layer can't produce fewer output tensors than there are inputs."); DLIB_CASSERT(data.num_samples()%std::distance(ibegin,iend) == 0, "The number of tensors produced by the input layer must be an integer multiple of the number of input objects."); _sample_expansion_factor = data.num_samples()/std::distance(ibegin,iend); data.async_copy_to_device(); } template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ) { to_tensor(ibegin,iend,temp_tensor); return forward(temp_tensor); } const tensor& operator() (const input_type& x) { return (*this)(&x, &x+1); } const tensor& forward (const tensor& x) { DLIB_CASSERT(sample_expansion_factor() != 0, "You must call to_tensor() before this function can be used."); DLIB_CASSERT(x.num_samples()%sample_expansion_factor() == 0); subnet_wrapper wsub(x, grad_final, _sample_expansion_factor); if (!this_layer_setup_called) { details.setup(wsub); this_layer_setup_called = true; } impl::call_layer_forward(details, wsub, cached_output); gradient_input_is_stale = true; return private_get_output(); } private: tensor& private_get_output() const { return const_cast(cached_output); } tensor& private_get_gradient_input() { if (gradient_input_is_stale) { gradient_input_is_stale = false; x_grad.copy_size(private_get_output()); x_grad = 0; } return x_grad; } void disable_output_and_gradient_getters ( ) { get_output_and_gradient_input_disabled = true; } public: const tensor& get_output() const { if (get_output_and_gradient_input_disabled) throw dlib::error("Accessing this layer's get_output() is disabled because an in-place layer has been stacked on top of it."); return private_get_output(); } tensor& get_gradient_input() { if (get_output_and_gradient_input_disabled) throw dlib::error("Accessing this layer's get_gradient_input() is disabled because an in-place layer has been stacked on top of it."); return private_get_gradient_input(); } const tensor& 
get_final_data_gradient( ) const { return grad_final; } void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ) { back_propagate_error(x, private_get_gradient_input(), zero_grads); } void back_propagate_error( const tensor& x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ) { // make sure grad_final is initialized to 0 if (!have_same_dimensions(x, grad_final)) grad_final.copy_size(x); grad_final = 0; subnet_wrapper wsub(x, grad_final, _sample_expansion_factor); params_grad.copy_size(details.get_layer_params()); impl::call_layer_backward(details, private_get_output(), gradient_input, wsub, static_cast(params_grad)); // zero out get_gradient_input() gradient_input_is_stale = zero_grads == zero_gradients::yes; } template void update_parameters(sstack solvers, double learning_rate) { DLIB_CASSERT(solvers.size()>=num_computational_layers); // Don't try to adjust the parameters if this layer doesn't have any or the // learning rate is disabled for this layer. 
if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0) { const tensor& step = solvers.top()(learning_rate, details, static_cast(params_grad)); tt::add(details.get_layer_params(), details.get_layer_params(), step); } } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const tensor& get_parameter_gradient( ) const { return params_grad; } tensor& get_parameter_gradient ( ) { return params_grad; } const subnet_type& subnet() const { return input_layer_; } subnet_type& subnet() { return input_layer_; } const subnet_type& input_layer() const { return input_layer_; } subnet_type& input_layer() { return input_layer_; } const layer_details_type& layer_details() const { return details; } layer_details_type& layer_details() { return details; } unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } void set_gradient_inputs_to_zero() { gradient_input_is_stale = true; } void clean() { x_grad.clear(); grad_final.clear(); cached_output.clear(); params_grad.clear(); temp_tensor.clear(); gradient_input_is_stale = true; call_clean_method_if_exists(details); } friend void serialize(const add_layer& item, std::ostream& out) { int version = 3; serialize(version, out); serialize(item.input_layer_, out); serialize(item.details, out); serialize(item.this_layer_setup_called, out); serialize(item.gradient_input_is_stale, out); serialize(item.get_output_and_gradient_input_disabled, out); serialize(item.x_grad, out); serialize(item.cached_output, out); serialize(item.grad_final, out); serialize(item._sample_expansion_factor, out); } friend void deserialize(add_layer& item, std::istream& in) { int version = 0; deserialize(version, in); if (!(2 <= version && version <= 3)) throw serialization_error("Unexpected version found while deserializing dlib::add_layer."); deserialize(item.input_layer_, in); deserialize(item.details, in); 
deserialize(item.this_layer_setup_called, in); deserialize(item.gradient_input_is_stale, in); deserialize(item.get_output_and_gradient_input_disabled, in); deserialize(item.x_grad, in); deserialize(item.cached_output, in); deserialize(item.grad_final, in); if (version >= 3) deserialize(item._sample_expansion_factor, in); else item._sample_expansion_factor = 1; // all layer types set this to 1 in older dlib versions, so that's what we put here. } friend std::ostream& operator<< (std::ostream& out, const add_layer& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << layer_details() << "\n"; // Don't print the repeat_input_layer since it doesn't exist from the user's // point of view. It's just an artifact of how repeat<> works. if (!std::is_same::value) out << "layer<" << idx+1 << ">\t" << subnet() << "\n"; } private: bool this_layer_requires_forward_output( ) { subnet_wrapper wsub(grad_final, grad_final, _sample_expansion_factor); return impl::backward_requires_forward_output(details, wsub); } class subnet_wrapper { public: subnet_wrapper(const tensor& x_, resizable_tensor& grad_final_, unsigned int sef) : x(x_), grad_final(grad_final_), _sample_expansion_factor(sef) {} subnet_wrapper(const subnet_wrapper&) = delete; subnet_wrapper& operator=(const subnet_wrapper&) = delete; unsigned int sample_expansion_factor() const { return _sample_expansion_factor;} const tensor& get_output() const { return x; } tensor& get_gradient_input() { if (!have_same_dimensions(x, grad_final)) { grad_final.copy_size(x); grad_final = 0; } return grad_final; } private: const tensor& x; resizable_tensor& grad_final; unsigned int _sample_expansion_factor; }; void swap(add_layer& item) { std::swap(input_layer_, item.input_layer_); std::swap(details, item.details); std::swap(this_layer_setup_called, 
item.this_layer_setup_called); std::swap(gradient_input_is_stale, item.gradient_input_is_stale); std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled); std::swap(x_grad, item.x_grad); std::swap(cached_output, item.cached_output); std::swap(grad_final, item.grad_final); std::swap(_sample_expansion_factor, item._sample_expansion_factor); } subnet_type input_layer_; LAYER_DETAILS details; bool this_layer_setup_called; bool gradient_input_is_stale; bool get_output_and_gradient_input_disabled; mutable unsigned int _sample_expansion_factor; resizable_tensor x_grad; resizable_tensor cached_output; resizable_tensor grad_final; // The following 2 objects don't logically contribute to the state of this class. // They are only here to prevent them from being reallocated over and over in // member functions. resizable_tensor params_grad; resizable_tensor temp_tensor; }; // ---------------------------------------------------------------------------------------- template class add_tag_layer; template class tag> struct tag_id { const static unsigned long id = tag::id; }; template class add_tag_layer::value>::type> { public: typedef SUBNET subnet_type; typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_computational_layers = subnet_type::num_computational_layers; const static unsigned long id = ID; add_tag_layer() {}; add_tag_layer(const add_tag_layer&) = default; add_tag_layer(add_tag_layer&&) = default; add_tag_layer& operator=(add_tag_layer&&) = default; add_tag_layer& operator=(const add_tag_layer&) = default; template add_tag_layer( const add_tag_layer& item ) : subnetwork(item.subnet()) {} template add_tag_layer( T ...args ) : subnetwork(std::move(args)...) 
{ } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { subnetwork.to_tensor(ibegin,iend,data); } template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ) { return subnetwork(ibegin,iend); } const tensor& operator() (const input_type& x) { return subnetwork(x); } const tensor& forward(const tensor& x) { return subnetwork.forward(x); } const tensor& get_output() const { return subnetwork.get_output(); } tensor& get_gradient_input() { return subnetwork.get_gradient_input(); } const tensor& get_final_data_gradient( ) const { return subnetwork.get_final_data_gradient(); } void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ) { subnetwork.back_propagate_error(x, zero_grads); } void back_propagate_error( const tensor& x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ) { subnetwork.back_propagate_error(x,gradient_input, zero_grads); } template void update_parameters(sstack solvers, double learning_rate) { subnetwork.update_parameters(solvers, learning_rate); } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const tensor& get_parameter_gradient( ) const { return params_grad; } tensor& get_parameter_gradient ( ) { return params_grad; } const subnet_type& subnet() const { return subnetwork; } subnet_type& subnet() { return subnetwork; } const input_layer_type& input_layer() const { return subnet().input_layer(); } input_layer_type& input_layer() { return subnet().input_layer(); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } void set_gradient_inputs_to_zero() { subnetwork.set_gradient_inputs_to_zero(); } void clean() { subnetwork.clean(); } friend void serialize(const add_tag_layer& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.subnetwork, out); } friend 
void deserialize(add_tag_layer& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer."); deserialize(item.subnetwork, in); } friend std::ostream& operator<< (std::ostream& out, const add_tag_layer& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { out << "layer<" << idx << ">\t" << impl::tensor_to_str(private_get_output(), min_length) << "tag" << ID << "\n"; subnet().print(out, idx+1, min_length); } private: template friend class add_layer; template friend class dimpl::subnet_wrapper; template friend class add_tag_layer; template class T, typename U> friend class add_skip_layer; template class L, typename S> friend class repeat; // You wouldn't put a tag on a layer if you didn't want to access its forward // outputs. So this is always true. bool this_layer_requires_forward_output( ) { return true; } void disable_output_and_gradient_getters ( ) { // This should never happen because only inplace layers call // disable_output_and_gradient_getters(), however, putting a tag layer right // before an inplace layer basically means you don't want the following layer // to operate in place. So the inplace layer should turn itself into an // out-of-place layer and not call disable_output_and_gradient_getters(). DLIB_CASSERT(false,"This should never happen"); } tensor& private_get_output() const { return subnetwork.private_get_output(); } tensor& private_get_gradient_input() { return subnetwork.private_get_gradient_input(); } subnet_type subnetwork; // This member doesn't logically contribute to the state of the object since it is // always empty. It's just here so we can have the get_parameter_gradient() methods // which have to return something. So they return this empty tensor. 
resizable_tensor params_grad; }; // ---------------------------------------------------------------------------------------- template struct decorator_repeat_group { decorator_repeat_group( T&& ...args ) : data(std::forward(args)...) {} std::tuple data; }; template decorator_repeat_group repeat_group ( T&& ...args ) { return decorator_repeat_group(std::forward(args)...); } template < size_t num, template class REPEATED_LAYER, typename SUBNET > class repeat { static_assert(num > 0, "You can't have a layer repeated 0 times."); public: typedef SUBNET subnet_type; typedef typename SUBNET::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. const static size_t comp_layers_in_each_group = (REPEATED_LAYER::num_computational_layers-SUBNET::num_computational_layers); const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num; const static size_t num_computational_layers = comp_layers_in_repeated_group + SUBNET::num_computational_layers; const static size_t layers_in_each_group = (REPEATED_LAYER::num_layers-SUBNET::num_layers); const static size_t layers_in_repeated_group = layers_in_each_group*num; const static size_t num_layers = subnet_type::num_layers + layers_in_repeated_group; typedef REPEATED_LAYER repeated_layer_type; repeat( ) : details(num) { } size_t num_repetitions ( ) const { return num; } const repeated_layer_type& get_repeated_layer ( size_t i ) const { DLIB_CASSERT(i < num_repetitions()); return details[i]; } repeated_layer_type& get_repeated_layer ( size_t i ) { DLIB_CASSERT(i < num_repetitions()); return details[i]; } repeat(const repeat&) = default; repeat(repeat&&) = default; repeat& operator=(repeat&&) = default; repeat& operator=(const repeat&) = default; template class T, typename U> repeat( const repeat& item ) : subnetwork(item.subnetwork) { for (auto&& d : item.details) details.emplace_back(d); } 
template repeat( T arg1, U ...args2 ): details(num, std::move(arg1)), subnetwork(std::move(args2)...) { } template repeat( decorator_repeat_group&& arg1, U ...args2 ): details(num, arg1.data), subnetwork(std::move(args2)...) { } template repeat( std::tuple<>, T arg1, U ...args2 ): details(num, std::move(arg1)), subnetwork(std::move(args2)...) { } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { subnetwork.to_tensor(ibegin,iend,data); // call to_tensor on the networks in details just to populate the // _sample_expansion_factor values in those networks. Other than that this // call is a noop. for (auto& d : details) d.to_tensor(ibegin, iend, data); } template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ) { to_tensor(ibegin,iend,temp_tensor); return forward(temp_tensor); } const tensor& operator() (const input_type& x) { return (*this)(&x, &x+1); } const tensor& forward(const tensor& x) { subnetwork.forward(x); details[details.size()-1].forward(subnetwork.get_output()); for (long i = details.size()-2; i >= 0; --i) details[i].forward(details[i+1].get_output()); return private_get_output(); } private: tensor& private_get_output() const { return details[0].private_get_output(); } tensor& private_get_gradient_input() { return details[0].private_get_gradient_input(); } public: const tensor& get_output() const { return details[0].get_output(); } tensor& get_gradient_input() { return details[0].get_gradient_input(); } const tensor& get_final_data_gradient( ) const { return subnetwork.get_final_data_gradient(); } const tensor& get_parameter_gradient( ) const { return details[0].get_parameter_gradient(); } tensor& get_parameter_gradient ( ) { return details[0].get_parameter_gradient(); } void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ) { back_propagate_error(x, private_get_gradient_input(), zero_grads); } void back_propagate_error( const tensor& 
x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ) { if (details.size() > 1) { details[0].back_propagate_error(details[1].get_output(), gradient_input, zero_grads); for (size_t i = 1; i < details.size(); ++i) { if (i+1 < details.size()) details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient(), zero_grads); else details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient(), zero_grads); } } else { details[0].back_propagate_error(subnetwork.get_output(), gradient_input, zero_grads); } subnetwork.back_propagate_error(x, details.back().get_final_data_gradient(), zero_grads); } template void update_parameters(sstack solvers, double learning_rate) { for (size_t i = 0; i < details.size(); ++i) details[i].update_parameters(solvers.pop(comp_layers_in_each_group*i),learning_rate); subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate); } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const subnet_type& subnet() const { return subnetwork; } subnet_type& subnet() { return subnetwork; } const input_layer_type& input_layer() const { return subnet().input_layer(); } input_layer_type& input_layer() { return subnet().input_layer(); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } void set_gradient_inputs_to_zero() { subnetwork.set_gradient_inputs_to_zero(); } void clean() { temp_tensor.clear(); subnetwork.clean(); for (auto&& d : details) d.clean(); } friend void serialize(const repeat& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.details, out); serialize(item.subnetwork, out); } friend void deserialize(repeat& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while 
deserializing dlib::repeat."); deserialize(item.details, in); deserialize(item.subnetwork, in); } friend std::ostream& operator<< (std::ostream& out, const repeat& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { for (size_t i = 0; i < num_repetitions(); ++i) { get_repeated_layer(i).print(out, idx, min_length); idx += layers_in_each_group; } subnet().print(out, idx, min_length); } private: template friend class add_layer; template friend class dimpl::subnet_wrapper; template friend class add_tag_layer; template class T, typename U> friend class add_skip_layer; template class L, typename S> friend class repeat; bool this_layer_requires_forward_output( ) { return details[0].this_layer_requires_forward_output(); } void disable_output_and_gradient_getters ( ) { details[0].disable_output_and_gradient_getters(); } std::vector details; subnet_type subnetwork; // temp_tensor doesn't logically contribute to the state of this class. // It is here only to void needing to reallocate it over and over. resizable_tensor temp_tensor; }; template < size_t num, template class REPEATED_LAYER, typename SUBNET > struct is_nonloss_layer_type> : std::true_type {}; // ---------------------------------------------------------------------------------------- // This version of add_tag_layer handles the special case where the subnetwork being given // is just an input layer object. template class add_tag_layer { public: typedef INPUT_LAYER subnet_type; typedef typename subnet_type::input_type input_type; typedef INPUT_LAYER input_layer_type; typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. 
const static size_t num_computational_layers = 0; const static size_t num_layers = 2; const static unsigned long id = ID; add_tag_layer():cached_output_ptr(nullptr),gradient_input_is_stale(true),_sample_expansion_factor(0) {} add_tag_layer(const add_tag_layer&) = default; add_tag_layer& operator=(const add_tag_layer&) = default; add_tag_layer(add_tag_layer&& item) : add_tag_layer() { swap(item); } add_tag_layer& operator=(add_tag_layer&& item) { swap(item); return *this; } template add_tag_layer( const add_tag_layer& item ) : input_layer_(item.subnet()), cached_output(item.cached_output), cached_output_ptr(nullptr), grad_final(item.grad_final), gradient_input_is_stale(item.gradient_input_is_stale), _sample_expansion_factor(0) {} template add_tag_layer( T ...args ) : input_layer_(std::move(args)...), cached_output_ptr(nullptr), gradient_input_is_stale(true), _sample_expansion_factor(0) { } add_tag_layer ( std::tuple<> ) : cached_output_ptr(nullptr), gradient_input_is_stale(true), _sample_expansion_factor(0) {} template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { input_layer_.to_tensor(ibegin,iend,data); // make sure the input layer's to_tensor() function is implemented properly. 
DLIB_CASSERT(data.num_samples() >= std::distance(ibegin,iend), "The input layer can't produce fewer output tensors than there are inputs."); DLIB_CASSERT(data.num_samples()%std::distance(ibegin,iend) == 0, "The number of tensors produced by the input layer must be an integer multiple of the number of input objects."); _sample_expansion_factor = data.num_samples()/std::distance(ibegin,iend); data.async_copy_to_device(); } unsigned int sample_expansion_factor() const { return _sample_expansion_factor; } template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ) { input_layer_.to_tensor(ibegin,iend,cached_output); cached_output_ptr = nullptr; return get_output(); } const tensor& operator() (const input_type& x) { return (*this)(&x, &x+1); } const tensor& forward(const tensor& x) { // If this tag is the first layer in one of the sub networks inside a repeat // layer then we don't want it to be creating copies of x. This is because, we // can just hold a pointer to x since the way repeat is constructed guarantees // that x will have a lifetime larger than this pointer. 
if (is_same_type::value) cached_output_ptr = const_cast(&x); else cached_output = x; gradient_input_is_stale = true; return get_output(); } const tensor& get_output() const { if (cached_output_ptr) return *cached_output_ptr; else return cached_output; } const tensor& get_final_data_gradient( ) const { return grad_final; } tensor& get_gradient_input() { if (!have_same_dimensions(get_output(), grad_final) || gradient_input_is_stale) { grad_final.copy_size(get_output()); grad_final = 0; gradient_input_is_stale = false; } return grad_final; } void back_propagate_error( const tensor& /*x*/, zero_gradients /*zero_grads*/ = zero_gradients::yes ) { // nothing to do } void back_propagate_error( const tensor& /*x*/, const tensor& /*gradient_input*/, zero_gradients /*zero_grads*/ = zero_gradients::yes ) { // nothing to do } template void update_parameters(sstack /*solvers*/, double /*learning_rate*/) { // nothing to do } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const subnet_type& subnet() const { return input_layer_; } subnet_type& subnet() { return input_layer_; } const input_layer_type& input_layer() const { return input_layer_; } input_layer_type& input_layer() { return input_layer_; } void set_gradient_inputs_to_zero() { // nothing to do } void clean() { grad_final.clear(); cached_output.clear(); cached_output_ptr = 0; } friend void serialize(const add_tag_layer& item, std::ostream& out) { int version = 2; serialize(version, out); serialize(item.input_layer_, out); serialize(item.cached_output, out); serialize(item.grad_final, out); serialize(item.gradient_input_is_stale, out); serialize(item._sample_expansion_factor, out); } friend void deserialize(add_tag_layer& item, std::istream& in) { int version = 0; deserialize(version, in); if (!(1 <= version && version <= 2)) throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer."); 
deserialize(item.input_layer_, in); deserialize(item.cached_output, in); deserialize(item.grad_final, in); deserialize(item.gradient_input_is_stale, in); item.cached_output_ptr = nullptr; if (version >= 2) deserialize(item._sample_expansion_factor, in); else item._sample_expansion_factor = 1; // all layer types set this to 1 in older dlib versions, so that's what we put here. } friend std::ostream& operator<< (std::ostream& out, const add_tag_layer& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { out << "layer<"<\t"< works. if (!std::is_same::value) out << "layer<"<< idx+1 << ">\t" << subnet() << "\n"; } private: template friend class add_layer; template friend class dimpl::subnet_wrapper; template friend class add_tag_layer; template class T, typename U> friend class add_skip_layer; template class L, typename S> friend class repeat; // You woudln't put a tag on a layer if you didn't want to access its forward // outputs. So this is always true. bool this_layer_requires_forward_output( ) { return true; } void disable_output_and_gradient_getters ( ) { // This should never happen because only inplace layers call // disable_output_and_gradient_getters(), however, putting a tag layer right // before an inplace layer basically means you don't want the following layer // to operate in place. So the inplace layer should turn itself into an // out-of-place layer and not call disable_output_and_gradient_getters(). 
DLIB_CASSERT(false,"This should never happen"); } tensor& private_get_output() const { return const_cast(get_output()); } tensor& private_get_gradient_input() { return get_gradient_input(); } void swap(add_tag_layer& item) { std::swap(input_layer_, item.input_layer_); std::swap(cached_output, item.cached_output); std::swap(cached_output_ptr, item.cached_output_ptr); std::swap(grad_final, item.grad_final); std::swap(gradient_input_is_stale, item.gradient_input_is_stale); std::swap(_sample_expansion_factor, item._sample_expansion_factor); } subnet_type input_layer_; resizable_tensor cached_output; tensor* cached_output_ptr; resizable_tensor grad_final; bool gradient_input_is_stale; mutable unsigned int _sample_expansion_factor; }; template struct is_nonloss_layer_type> : std::true_type {}; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template class add_loss_layer; class no_label_type { private: // We don't want anyone making these no_label_type objects. They are here only to // allow add_loss_layer::training_label_type and dnn_trainer::training_label_type // to exist which avoids needing to overload add_loss_layer and dnn_trainer for // supervised an unsupervised losses. It also can be a type to use in template // metaprogramming to indicate "no label". So here we make the constructor private // with the exception that add_loss_layer objects can make it (again, just to // simplify add_loss_layer's implementation). 
no_label_type(){}; template friend class add_loss_layer; template < typename net_type, typename solver_type > friend class dnn_trainer; }; // ---------------------------------------------------------------------------------------- template class add_loss_layer { template struct get_loss_layer_training_label_type { typedef no_label_type type; }; template struct get_loss_layer_training_label_type::type> { typedef typename T::training_label_type type; }; template struct get_loss_layer_output_label_type { typedef no_label_type type; }; template struct get_loss_layer_output_label_type::type> { typedef typename T::output_label_type type; }; public: typedef LOSS_DETAILS loss_details_type; typedef SUBNET subnet_type; typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; const static size_t num_layers = subnet_type::num_layers + 1; // Note that the loss layer doesn't count as an additional computational layer. const static size_t num_computational_layers = subnet_type::num_computational_layers; typedef typename get_loss_layer_training_label_type::type training_label_type; typedef typename get_loss_layer_output_label_type::type output_label_type; static_assert(is_nonloss_layer_type::value, "SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."); add_loss_layer() {}; add_loss_layer(const add_loss_layer&) = default; add_loss_layer& operator=(const add_loss_layer&) = default; add_loss_layer(add_loss_layer&& item) : add_loss_layer() { swap(item); } add_loss_layer& operator=(add_loss_layer&& item) { swap(item); return *this; } template add_loss_layer( const add_loss_layer& item ) : loss(item.loss_details()), subnetwork(item.subnet()) {} template add_loss_layer( const LOSS_DETAILS& layer_det, T&& ...args ) : loss(layer_det), subnetwork(std::forward(args)...) { } template add_loss_layer( LOSS_DETAILS&& layer_det, T&& ...args ) : loss(std::move(layer_det)), subnetwork(std::forward(args)...) 
{ } template struct disable_forwarding_constr { const static bool value = std::is_constructible::value; }; template struct disable_forwarding_constr> { const static bool value = true; }; template < typename ...T, typename = typename std::enable_if::type...>::value>::type > add_loss_layer( T&& ...args ) : subnetwork(std::forward(args)...) { } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { subnetwork.to_tensor(ibegin,iend,data); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } template void operator() ( const tensor& x, output_iterator obegin ) { subnetwork.forward(x); const dimpl::subnet_wrapper wsub(subnetwork); loss.to_label(x, wsub, obegin); } template void operator() ( forward_iterator ibegin, forward_iterator iend, output_iterator obegin ) { to_tensor(ibegin,iend,temp_tensor); (*this)(temp_tensor, obegin); } const output_label_type& operator() (const input_type& x) { (*this)(&x, &x+1, &temp_label); return temp_label; } template const output_label_type& process (const input_type& x, T&& ...args) { to_tensor(&x,&x+1,temp_tensor); subnetwork.forward(temp_tensor); const dimpl::subnet_wrapper wsub(subnetwork); loss.to_label(temp_tensor, wsub, &temp_label, std::forward(args)...); return temp_label; } template std::vector process_batch (const iterable_type& data, size_t batch_size, T&& ...args) { std::vector results(std::distance(data.begin(), data.end())); auto o = results.begin(); auto i = data.begin(); auto num_remaining = results.size(); while(num_remaining != 0) { auto inc = std::min(batch_size, num_remaining); to_tensor(i,i+inc,temp_tensor); subnetwork.forward(temp_tensor); const dimpl::subnet_wrapper wsub(subnetwork); loss.to_label(temp_tensor, wsub, o, std::forward(args)...); i += inc; o += inc; num_remaining -= inc; } return results; } void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ) { 
subnet().back_propagate_error(x, zero_grads); } void back_propagate_error( const tensor& x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ) { subnet().back_propagate_error(x, gradient_input, zero_grads); } const tensor& get_final_data_gradient( ) const { return subnet().get_final_data_gradient(); } const tensor& forward(const tensor& x) { return subnet().forward(x); } template std::vector operator() ( const iterable_type& data, size_t batch_size = 128 ) { std::vector results(std::distance(data.begin(), data.end())); auto o = results.begin(); auto i = data.begin(); auto num_remaining = results.size(); while(num_remaining != 0) { auto inc = std::min(batch_size, num_remaining); (*this)(i, i+inc, o); i += inc; o += inc; num_remaining -= inc; } return results; } template double compute_loss ( const tensor& x, label_iterator lbegin ) { subnetwork.forward(x); dimpl::subnet_wrapper wsub(subnetwork); return loss.compute_loss_value_and_gradient(x, lbegin, wsub); } template double compute_loss ( forward_iterator ibegin, forward_iterator iend, label_iterator lbegin ) { to_tensor(ibegin,iend,temp_tensor); return compute_loss(temp_tensor, lbegin); } double compute_loss ( const tensor& x ) { subnetwork.forward(x); dimpl::subnet_wrapper wsub(subnetwork); return loss.compute_loss_value_and_gradient(x, wsub); } template double compute_loss ( forward_iterator ibegin, forward_iterator iend ) { to_tensor(ibegin,iend,temp_tensor); return compute_loss(temp_tensor); } template double compute_parameter_gradients ( const tensor& x, label_iterator lbegin, zero_gradients zero_grads = zero_gradients::yes ) { subnetwork.forward(x); dimpl::subnet_wrapper wsub(subnetwork); double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub); subnetwork.back_propagate_error(x, zero_grads); return l; } template double compute_parameter_gradients ( forward_iterator ibegin, forward_iterator iend, label_iterator lbegin, zero_gradients zero_grads = zero_gradients::yes ) { 
to_tensor(ibegin,iend,temp_tensor); return compute_parameter_gradients(temp_tensor, lbegin, zero_grads); } double compute_parameter_gradients ( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ) { subnetwork.forward(x); dimpl::subnet_wrapper wsub(subnetwork); double l = loss.compute_loss_value_and_gradient(x, wsub); subnetwork.back_propagate_error(x, zero_grads); return l; } template double compute_parameter_gradients ( forward_iterator ibegin, forward_iterator iend, zero_gradients zero_grads = zero_gradients::yes ) { to_tensor(ibegin,iend,temp_tensor); return compute_parameter_gradients(temp_tensor, zero_grads); } template void update_parameters ( sstack solvers, double learning_rate ) { subnetwork.update_parameters(solvers, learning_rate); } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const subnet_type& subnet() const { return subnetwork; } subnet_type& subnet() { return subnetwork; } const input_layer_type& input_layer() const { return subnet().input_layer(); } input_layer_type& input_layer() { return subnet().input_layer(); } const loss_details_type& loss_details() const { return loss; } loss_details_type& loss_details() { return loss; } void set_gradient_inputs_to_zero ( ) { subnetwork.set_gradient_inputs_to_zero(); } void clean ( ) { temp_tensor.clear(); subnetwork.clean(); } template friend void serialize(const add_loss_layer& item, std::ostream& out); template friend void deserialize(add_loss_layer& item, std::istream& in); friend std::ostream& operator<< (std::ostream& out, const add_loss_layer& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { out << "layer<" << idx << ">\t" << loss_details() << "\n"; subnet().print(out, idx+1, min_length); } private: void swap(add_loss_layer& item) { std::swap(loss, item.loss); std::swap(subnetwork, item.subnetwork); 
}

loss_details_type loss;
subnet_type subnetwork;

// These two objects don't logically contribute to the state of this object.  They
// are here to prevent them from being reallocated over and over.
output_label_type temp_label;
resizable_tensor temp_tensor;
};

// Serialized format: an int version tag (currently 1), then the loss details, then
// the subnetwork.
template void serialize(const add_loss_layer& item, std::ostream& out)
{
    int version = 1;
    serialize(version, out);
    serialize(item.loss, out);
    serialize(item.subnetwork, out);
}

// Reads the format written by the add_loss_layer serialize() overload; throws
// serialization_error if the version tag is not 1.
template void deserialize(add_loss_layer& item, std::istream& in)
{
    int version = 0;
    deserialize(version, in);
    if (version != 1)
        throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer.");
    deserialize(item.loss, in);
    deserialize(item.subnetwork, in);
}

template struct is_loss_layer_type> : std::true_type {};

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

namespace impl
{
    // Generic case of the compile-time lookup behind layer(): it walks down the
    // network through subnet() until the requested layer index is reached.
    template struct layer_helper
    {
        static_assert(i < T::num_layers, "Call to layer() attempted to access non-existing layer in neural network.");
        static T& makeT();
        // If you get an error here mentioning the lack of a member "subnet" in
        // "dlib::input<...>", then your "dlib::layer<...>" invocation likely wasn't able
        // to find the requested layer.  This could happen, for instance, when trying to
        // use a skip layer for a non-existing tag.
using next_type = typename std::remove_reference::type; using type = typename layer_helper::type; static type& layer(T& n) { return layer_helper::layer(n.subnet()); } }; template < unsigned int i, size_t N, template class L, typename S > struct layer_helper, typename std::enable_if<(i!=0&&i>=repeat::layers_in_repeated_group)>::type> { const static size_t layers_in_repeated_group = repeat::layers_in_repeated_group; static repeat& makeT(); using next_type = typename std::remove_reference::type; using type = typename layer_helper::type; static type& layer(repeat& n) { return layer_helper::layer(n.subnet()); } }; template < unsigned int i, size_t N, template class L, typename S > struct layer_helper, typename std::enable_if<(i!=0&&i::layers_in_repeated_group)>::type> { const static size_t layers_in_each_group = repeat::layers_in_each_group; typedef typename repeat::repeated_layer_type repeated_layer_type; using next_type = repeated_layer_type; using type = typename layer_helper::type; static type& layer(repeat& n) { return layer_helper::layer(n.get_repeated_layer(i/layers_in_each_group)); } }; template < size_t N, template class L, typename S > struct layer_helper<0,repeat, void> { typedef typename repeat::repeated_layer_type repeated_layer_type; using type = repeated_layer_type; static type& layer(repeat& n) { return n.get_repeated_layer(0); } }; template < unsigned int i, size_t N, template class L, typename S > struct layer_helper, typename std::enable_if<(i!=0&&i>=repeat::layers_in_repeated_group)>::type> { const static size_t layers_in_repeated_group = repeat::layers_in_repeated_group; static const repeat& makeT(); using next_type = const typename std::remove_reference::type; using type = const typename layer_helper::type; static type& layer(const repeat& n) { return layer_helper::layer(n.subnet()); } }; template < unsigned int i, size_t N, template class L, typename S > struct layer_helper, typename std::enable_if<(i!=0&&i::layers_in_repeated_group)>::type> { 
const static size_t layers_in_each_group = repeat::layers_in_each_group; typedef typename repeat::repeated_layer_type repeated_layer_type; using next_type = const repeated_layer_type; using type = const typename layer_helper::type; static type& layer(const repeat& n) { return layer_helper::layer(n.get_repeated_layer(i/layers_in_each_group)); } }; template < size_t N, template class L, typename S > struct layer_helper<0,const repeat, void> { typedef typename repeat::repeated_layer_type repeated_layer_type; using type = const repeated_layer_type; static type& layer(const repeat& n) { return n.get_repeated_layer(0); } }; template struct layer_helper<0,T,void> { using type = T; static type& layer(T& n) { return n; } }; template class Match, typename T, unsigned int i, typename enabled = void> struct layer_helper_match { static T& makeT(); using next_type = typename std::remove_reference::type; using type = typename layer_helper_match::type; static type& layer(T& n) { return layer_helper_match::layer(n.subnet()); } }; // This overload catches add_layer and add_loss_layer templates. template class Match, typename T, unsigned int i> struct layer_helper_match>::value>::type> { using type = typename layer_helper::type; static type& layer(T& n) { return layer_helper::layer(n); } }; // This overload catches input templates. template class Match, typename T, unsigned int i> struct layer_helper_match>::value>::type> { using type = typename layer_helper::type; static type& layer(T& n) { return layer_helper::layer(n); } }; // This overload catches subnet_wrapper templates. 
template class Match, typename T, unsigned int i> struct layer_helper_match>::value>::type> { using type = typename layer_helper::type; static type& layer(T& n) { return layer_helper::layer(n); } }; } template typename impl::layer_helper::type& layer (T& n) { return impl::layer_helper::layer(n); } template class Match, typename T> typename impl::layer_helper_match::type& layer (T& n) { return impl::layer_helper_match::layer(n); } template class Match, unsigned int i, typename T> typename impl::layer_helper_match::type& layer (T& n) { return impl::layer_helper_match::layer(n); } // ---------------------------------------------------------------------------------------- template typename net_type::input_layer_type& input_layer ( net_type& net ) { return net.input_layer(); } template const typename net_type::input_layer_type& input_layer ( const net_type& net ) { return net.input_layer(); } // ---------------------------------------------------------------------------------------- template class TAG_TYPE, typename SUBNET> class add_skip_layer { public: typedef SUBNET subnet_type; typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper. const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_computational_layers = subnet_type::num_computational_layers; const static unsigned long id = tag_id::id; add_skip_layer() {}; add_skip_layer(const add_skip_layer&) = default; add_skip_layer(add_skip_layer&&) = default; add_skip_layer& operator=(add_skip_layer&&) = default; add_skip_layer& operator=(const add_skip_layer&) = default; template add_skip_layer( const add_skip_layer& item ) : subnetwork(item.subnet()) {} template add_skip_layer( T ...args ) : subnetwork(std::move(args)...) 
{ } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { subnetwork.to_tensor(ibegin,iend,data); } template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ) { subnetwork(ibegin,iend); return layer(subnetwork).get_output(); } const tensor& operator() (const input_type& x) { subnetwork(x); return layer(subnetwork).get_output(); } const tensor& forward(const tensor& x) { subnetwork.forward(x); return layer(subnetwork).get_output(); } const tensor& get_output() const { return layer(subnetwork).get_output(); } tensor& get_gradient_input() { return layer(subnetwork).get_gradient_input(); } const tensor& get_final_data_gradient( ) const { return subnetwork.get_final_data_gradient(); } void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ) { subnetwork.back_propagate_error(x, zero_grads); } template void update_parameters(sstack solvers, double learning_rate) { subnetwork.update_parameters(solvers, learning_rate); } template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } const tensor& get_parameter_gradient( ) const { return params_grad; } tensor& get_parameter_gradient ( ) { return params_grad; } const subnet_type& subnet() const { return subnetwork; } subnet_type& subnet() { return subnetwork; } const input_layer_type& input_layer() const { return subnet().input_layer(); } input_layer_type& input_layer() { return subnet().input_layer(); } unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); } void set_gradient_inputs_to_zero() { subnetwork.set_gradient_inputs_to_zero(); } void clean() { subnetwork.clean(); } friend void serialize(const add_skip_layer& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.subnetwork, out); } friend void deserialize(add_skip_layer& item, std::istream& in) { int version = 0; 
deserialize(version, in); if (version != 1) throw serialization_error("Unexpected version found while deserializing dlib::add_skip_layer."); deserialize(item.subnetwork, in); } friend std::ostream& operator<< (std::ostream& out, const add_skip_layer& item) { int min_length = 0; item.print(out, 0, min_length); return out; } void print (std::ostream& out, unsigned long idx, int& min_length) const { out << "layer<" << idx << ">\t"< friend class add_layer; template friend class dimpl::subnet_wrapper; template friend class add_tag_layer; template class T, typename U> friend class add_skip_layer; template class L, typename S> friend class repeat; bool this_layer_requires_forward_output( ) { return layer(subnetwork).this_layer_requires_forward_output(); } void disable_output_and_gradient_getters ( ) { layer(subnetwork).disable_output_and_gradient_getters(); } tensor& private_get_output() const { return layer(subnetwork).private_get_output(); } tensor& private_get_gradient_input() { return layer(subnetwork).private_get_gradient_input(); } subnet_type subnetwork; // This member doesn't logically contribute to the state of the object since it is // always empty. It's just here so we can have the get_parameter_gradient() methods // which have to return something. So they return this empty tensor. 
resizable_tensor params_grad; }; template class T, typename U> struct is_nonloss_layer_type> : std::true_type {}; template using tag1 = add_tag_layer< 1, SUBNET>; template using tag2 = add_tag_layer< 2, SUBNET>; template using tag3 = add_tag_layer< 3, SUBNET>; template using tag4 = add_tag_layer< 4, SUBNET>; template using tag5 = add_tag_layer< 5, SUBNET>; template using tag6 = add_tag_layer< 6, SUBNET>; template using tag7 = add_tag_layer< 7, SUBNET>; template using tag8 = add_tag_layer< 8, SUBNET>; template using tag9 = add_tag_layer< 9, SUBNET>; template using tag10 = add_tag_layer<10, SUBNET>; template using skip1 = add_skip_layer< tag1, SUBNET>; template using skip2 = add_skip_layer< tag2, SUBNET>; template using skip3 = add_skip_layer< tag3, SUBNET>; template using skip4 = add_skip_layer< tag4, SUBNET>; template using skip5 = add_skip_layer< tag5, SUBNET>; template using skip6 = add_skip_layer< tag6, SUBNET>; template using skip7 = add_skip_layer< tag7, SUBNET>; template using skip8 = add_skip_layer< tag8, SUBNET>; template using skip9 = add_skip_layer< tag9, SUBNET>; template using skip10 = add_skip_layer; // ---------------------------------------------------------------------------------------- namespace timpl { inline void fill_with_gassuan_random_numbers ( tensor& t, dlib::rand& rnd, double sigma = 1 ) { float* data = t.host(); for (size_t i = 0; i < t.size(); ++i) data[i] = rnd.get_random_gaussian()*sigma; } class test_layer_subnet { public: test_layer_subnet ( dlib::rand& rnd_ ) : rnd(rnd_) { // Output and gradient_input have to have the same dimensions in each // layer. 
const long num_samples = rnd.get_random_32bit_number()%4+3; const long k = rnd.get_random_32bit_number()%4+2; const long nr = ((rnd.get_random_32bit_number()%4)/2)*2+2; const long nc = ((rnd.get_random_32bit_number()%4)/2)*2+2; output.set_size(num_samples, k, nr, nc); gradient_input.set_size(num_samples, k, nr, nc); // Use a non-zero initial gradient to make sure the layers add to it // rather than assign and blow away the initial value. fill_with_gassuan_random_numbers(gradient_input, rnd, 0.01); fill_with_gassuan_random_numbers(output, rnd); } tensor& get_mutable_output() { return output; } const tensor& get_output() const { return output; } const tensor& private_get_output() const { return get_output(); } const test_layer_subnet& subnet() const { init_sub(); return *subnetwork; } tensor& get_gradient_input() { return gradient_input; } tensor& private_get_gradient_input() { return get_gradient_input(); } test_layer_subnet& subnet() { init_sub(); return *subnetwork; } unsigned long count_outputs() const { if (subnetwork) return subnetwork->count_outputs() + output.size(); else return output.size(); } float& get_output_element(unsigned long i) { if (i < output.size()) return output.host()[i]; else return subnet().get_output_element(i-output.size()); } float get_gradient_input_element(unsigned long i) const { if (i < gradient_input.size()) return gradient_input.host()[i]; else return subnet().get_gradient_input_element(i-gradient_input.size()); } private: // We lazily initialize sub-layers as needed when someone tries to call // subnet() void init_sub() const { if (!subnetwork) subnetwork.reset(new test_layer_subnet(rnd)); } dlib::rand& rnd; mutable std::unique_ptr subnetwork; resizable_tensor output; resizable_tensor gradient_input; }; } struct layer_test_results { layer_test_results() : was_good(true) {} explicit layer_test_results(const std::string& l) : log(l),was_good(false) {} std::string log; bool was_good; operator bool() const { return was_good; } }; inline 
std::ostream& operator<< (std::ostream& out, const layer_test_results& item)
{
    out << item.log;
    return out;
}

// Numerically checks a layer's gradients using derivative step size base_eps.
// Over 10 randomized trials it compares the parameter and data gradients produced by
// the layer's backward pass against central-difference approximations of
// f = dot(output, input_grad).  For in-place layers it also checks that in-place and
// out-of-place execution agree.  Returns a passing layer_test_results on success, or
// a failing one carrying a diagnostic log.
template < typename layer_details_type >
layer_test_results impl_test_layer (
    layer_details_type l,
    const float base_eps
)
{
    using namespace timpl;
    // Do some setup
    running_stats rs_data, rs_params;
    dlib::rand rnd;
    std::ostringstream sout;
    for (int iter = 0; iter < 10; ++iter)
    {
        test_layer_subnet subnetwork(rnd);
        resizable_tensor output, out2, out3;
        // Run setup() and forward() as well to make sure any calls to subnet() have
        // happened before we start assuming we know how many data elements there are
        // (since we do a lazy layer creation thing based on calls to subnet() inside
        // test_layer_subnet).
        l.setup(subnetwork);
        impl::call_layer_forward(l, subnetwork, output);

        resizable_tensor input_grad;
        input_grad.copy_size(output);
        fill_with_gassuan_random_numbers(input_grad, rnd);

        // The f() we are computing gradients of is this thing.  Its value at the current
        // parameter and data values is:
        //sout << "f(data,params): " << dot(output, input_grad) << std::endl;

        // We are going to save a copy of the subnetwork.get_gradient_input() data before we do
        // backpropagation since the backward() function is supposed to *add* to the
        // gradients rather than overwrite them.  We will use this saved data to check if
        // that is the case.
        const unsigned long num_data_inputs = subnetwork.count_outputs();
        std::vector initial_gradient_input(num_data_inputs);
        for (unsigned long i = 0; i < num_data_inputs; ++i)
            initial_gradient_input[i] = subnetwork.get_gradient_input_element(i);

        // Now tell the layer to compute all the gradients.  In the rest of this function
        // we will just be checking that these gradients were computed correctly by
        // comparing them to a central differences approximation.
        resizable_tensor params_grad;
        params_grad.copy_size(l.get_layer_params());
        // But first, set the params grad to something crazy so that it's very obvious if
        // it doesn't get fully assigned.
params_grad = std::numeric_limits::infinity();
impl::call_layer_backward(l, output, input_grad, subnetwork, params_grad);

static_assert(impl::is_inplace_layer(l, subnetwork) == impl::has_inplace_backward(l, subnetwork),
    "Layer not defined correctly. forward and backward methods must either both be in-place or both out-of-place. ");

// Make sure the outputs of forward() and backward() are the same when they are run
// in in-place mode.
if (impl::is_inplace_layer(l, subnetwork))
{
    test_layer_subnet subnetwork2(rnd);
    layer_details_type ll(l);
    ll.setup(subnetwork2);
    resizable_tensor ip_out;
    impl::call_layer_forward(ll, subnetwork2, ip_out);
    // Second forward call writes into the subnetwork's own output tensor, i.e. it
    // runs in-place; its result must match the out-of-place result in ip_out.
    impl::call_layer_forward(ll, subnetwork2, subnetwork2.get_mutable_output());
    const auto forward_error = max(abs(mat(ip_out) - mat(subnetwork2.get_output())));
    if (forward_error > 0.00001)
    {
        sout << "This layer is supposed to support in-place computations but the output of forward_inplace()\n";
        sout << "changes when invoked in-place vs. out-of-place. The error was: " << forward_error << std::endl;
        return layer_test_results(sout.str());
    }

    // NOTE: these locals intentionally shadow the outer params_grad/input_grad; they
    // are used only by this in-place consistency check.
    resizable_tensor params_grad;
    params_grad.copy_size(ll.get_layer_params());
    params_grad = std::numeric_limits::infinity();

    resizable_tensor input_grad;
    input_grad.copy_size(ip_out);
    fill_with_gassuan_random_numbers(input_grad, rnd);
    resizable_tensor params_grad1, params_grad2, data_grad1, data_grad2;
    params_grad1 = params_grad;
    params_grad2 = params_grad;
    // Now call backward() and make sure it works as well.  Recall that when an
    // in-place layer works in-place it assigns to its outputs but when it's
    // not running in-place it adds.  So we initialize to a non-zero value to
    // check that this is the behavior that really executes.
subnetwork2.get_gradient_input() = 9; impl::call_layer_backward(ll, ip_out, input_grad, subnetwork2, params_grad1); data_grad1 = subnetwork2.get_gradient_input(); subnetwork2.get_gradient_input() = mat(input_grad); impl::call_layer_backward(ll, ip_out, subnetwork2.get_gradient_input(), subnetwork2, params_grad2); data_grad2 = subnetwork2.get_gradient_input(); if (params_grad.size() != 0) { const auto backward_param_error = max(abs(mat(params_grad1) - mat(params_grad2))); if (backward_param_error > 0.00001) { sout << "This layer is supposed to support in-place computations but the output of backward_inplace()\n"; sout << "changes when invoked in-place vs. out-of-place. The error was: " << backward_param_error << std::endl; return layer_test_results(sout.str()); } } const auto backward_data_error = max(abs(mat(data_grad1)-9 - mat(data_grad2))); if (backward_data_error > 0.00001) { sout << "This layer is supposed to support in-place computations but the output of backward_inplace()\n"; sout << "changes when invoked in-place vs. out-of-place. The error was: " << backward_data_error << std::endl; return layer_test_results(sout.str()); } } // ================================================================== // first validate the way the parameter gradients are computed for (unsigned long i = 0; i < params_grad.size(); ++i) { layer_details_type l1(l); float eps = l1.get_layer_params().host()[i]*base_eps; if (eps == 0) eps = base_eps; const float oldval = l1.get_layer_params().host()[i]; l1.get_layer_params().host()[i] = oldval+eps; impl::call_layer_forward(l1, subnetwork, out2); l1.get_layer_params().host()[i] = oldval-eps; impl::call_layer_forward(l1, subnetwork, out3); l1.get_layer_params().host()[i] = oldval; // Compute a reference derivative via a central differences approximation and // compare it to the one output by the layer and make sure they match. 
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps); double output_derivative = params_grad.host()[i]; double relative_error; if (reference_derivative*output_derivative != 0) relative_error = (reference_derivative - output_derivative)/(reference_derivative); else relative_error = (reference_derivative - output_derivative); double absolute_error = (reference_derivative - output_derivative); rs_params.add(std::abs(relative_error)); if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006) { sout << "Gradient error in parameter #" << i <<". Relative error: "<< relative_error << std::endl; sout << "expected derivative: " << reference_derivative << std::endl; sout << "output derivative: " << output_derivative << std::endl; sout << "iteration: " << iter << std::endl; return layer_test_results(sout.str()); } } // ================================================================== // now validate the data gradients for (unsigned long i = 0; i < num_data_inputs; ++i) { const float oldval = subnetwork.get_output_element(i); float eps = oldval*base_eps; if (eps == 0) eps = base_eps; subnetwork.get_output_element(i) = oldval+eps; impl::call_layer_forward(l, subnetwork, out2); subnetwork.get_output_element(i) = oldval-eps; impl::call_layer_forward(l, subnetwork, out3); subnetwork.get_output_element(i) = oldval; // Compute a reference derivative via a central differences approximation and // compare it to the one output by the layer and make sure they match. 
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps); double output_derivative = subnetwork.get_gradient_input_element(i); output_derivative -= initial_gradient_input[i]; double relative_error; if (reference_derivative*output_derivative != 0) relative_error = (reference_derivative - output_derivative)/(reference_derivative); else relative_error = (reference_derivative - output_derivative); double absolute_error = (reference_derivative - output_derivative); rs_data.add(std::abs(relative_error)); if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006) { sout << "Gradient error in data variable #" << i <<". Relative error: "<< relative_error << std::endl; sout << "expected derivative: " << reference_derivative << std::endl; sout << "output derivative: " << output_derivative << std::endl; sout << "iteration: " << iter << std::endl; return layer_test_results(sout.str()); } } } // end for (int iter = 0; iter < 10; ++iter) if (rs_params.mean() > 0.003) { sout << "Average parameter gradient error is somewhat large at: "<< rs_params.mean() << std::endl; return layer_test_results(sout.str()); } if (rs_data.mean() > 0.003) { sout << "Average data gradient error is somewhat large at: "<< rs_data.mean() << std::endl; return layer_test_results(sout.str()); } return layer_test_results(); } template < typename layer_details_type > layer_test_results test_layer ( layer_details_type l ) { // Try a few different derivative step sizes to see if any work. for (float base_eps = 0.0001; base_eps < 0.1; base_eps *= 2) { auto result = impl_test_layer(l, base_eps); if (result) return result; } // However, if none of the step sizes worked then try this one and probably result // in returning an error. 
return impl_test_layer(l, 0.01);
}

// ----------------------------------------------------------------------------------------

namespace impl
{
    // Compile-time recursion used by visit_layers(): visits the layer at index i and
    // then recurses into the next vl_loop instantiation.  The visitor is handed
    // (i, layer) and (layer) via call_if_valid, so either visitor signature works.
    // NOTE(review): the recursive and specialization template arguments are not
    // visible in this copy of the source.
    template struct vl_loop
    {
        template <
            typename net_type,
            typename visitor
            >
        static void visit(
            net_type& net,
            visitor&& v
        )
        {
            // Call whatever version of the visitor the user provided.
            call_if_valid(v, i, layer(net));
            call_if_valid(v, layer(net));
            vl_loop::visit(net,v);
        }
    };

    template struct vl_loop
    {
        template <
            typename net_type,
            typename visitor
            >
        static void visit(
            net_type&,
            visitor&&
        )
        {
            // Base case of recursion.  Don't do anything.
        }
    };

    // Same as vl_loop, except the recursive call happens before the visitor is
    // invoked, so later layers are visited before earlier ones.
    template struct vl_loop_backwards
    {
        template <
            typename net_type,
            typename visitor
            >
        static void visit(
            net_type& net,
            visitor&& v
        )
        {
            vl_loop_backwards::visit(net,v);
            // Call whatever version of the visitor the user provided.
            call_if_valid(v, i, layer(net));
            call_if_valid(v, layer(net));
        }
    };

    template struct vl_loop_backwards
    {
        template <
            typename net_type,
            typename visitor
            >
        static void visit(
            net_type&,
            visitor&&
        )
        {
            // Base case of recursion.  Don't do anything.
} }; } template < typename net_type, typename visitor > void visit_layers( net_type& net, visitor v ) { impl::vl_loop<0, net_type::num_layers>::visit(net, v); } template < typename net_type, typename visitor > void visit_layers_backwards( net_type& net, visitor v ) { impl::vl_loop_backwards<0, net_type::num_layers>::visit(net, v); } template < size_t begin, size_t end, typename net_type, typename visitor > void visit_layers_range( net_type& net, visitor v ) { static_assert(begin <= end, "Invalid range"); static_assert(end <= net_type::num_layers, "Invalid range"); impl::vl_loop::visit(net, v); } template < size_t begin, size_t end, typename net_type, typename visitor > void visit_layers_backwards_range( net_type& net, visitor v ) { static_assert(begin <= end, "Invalid range"); static_assert(end <= net_type::num_layers, "Invalid range"); impl::vl_loop_backwards::visit(net, v); } // ---------------------------------------------------------------------------------------- namespace impl { template struct vl_until_tag { template < typename net_type, typename next_net_type, typename visitor > static void visit( net_type& net, next_net_type& next_net, visitor&& v ) { call_if_valid(v, next_net); vl_until_tag::visit(net,layer(net),v); } template < typename net_type, typename SUBNET, typename visitor > static void visit( net_type&, const add_tag_layer& next_net, visitor&& v ) { call_if_valid(v, next_net); } template < typename net_type, typename SUBNET, typename visitor > static void visit( net_type&, add_tag_layer& next_net, visitor&& v ) { call_if_valid(v, next_net); } }; } template < unsigned long tag_id, typename net_type, typename visitor > void visit_layers_until_tag( net_type& net, visitor v ) { impl::vl_until_tag<0,tag_id>::visit(net, net, v); } // ---------------------------------------------------------------------------------------- namespace impl { template < typename visitor > class visitor_computational_layer { public: explicit 
visitor_computational_layer(visitor& v) : v_(v) {}

// Calls the user's visitor on l.layer_details().  call_if_valid is tried with
// (idx, details) and with (details) alone, so either visitor signature works.
template void do_visit(size_t idx, layer& l) const
{
    // Call whatever version of the visitor the user provided.
    call_if_valid(v_, idx, l.layer_details());
    call_if_valid(v_, l.layer_details());
}

// const case
template void operator()(size_t idx, const add_layer& l) const { do_visit(idx, l); }

// non-const case
template void operator()(size_t idx, add_layer& l) const { do_visit(idx, l); }

private:

visitor& v_;
};
}

// Like visit_layers(), but the adapter above only accepts add_layer objects, and it
// hands the user's visitor each such layer's layer_details().
template <
    typename net_type,
    typename visitor
    >
void visit_computational_layers(
    net_type& net,
    visitor v
)
{
    visit_layers(net, impl::visitor_computational_layer(v));
}

// Same as visit_computational_layers(), restricted to layer indices [begin, end).
template <
    size_t begin,
    size_t end,
    typename net_type,
    typename visitor
    >
void visit_computational_layers_range(
    net_type& net,
    visitor v
)
{
    visit_layers_range(net, impl::visitor_computational_layer(v));
}

// ----------------------------------------------------------------------------------------

namespace impl
{
    // Adapter that hands the user's visitor each computational layer's
    // get_layer_params(), optionally together with a running computational-layer index.
    template <
        typename visitor
        >
    class visit_layer_parameters
    {
    public:
        explicit visit_layer_parameters(visitor& v) : v_(v) {}

        template void operator()(layer& l)
        {
            // Call whatever version of the visitor the user provided.
const bool visitor_called = call_if_valid(v_, computational_layer_idx, l.get_layer_params()) || call_if_valid(v_, l.get_layer_params()); DLIB_CASSERT(visitor_called, "A visitor function with an incorrect signature was given to visit_layer_parameters()"); ++computational_layer_idx; } private: size_t computational_layer_idx = 0; visitor& v_; }; } template < typename net_type, typename visitor > void visit_layer_parameters( net_type& net, visitor v ) { visit_computational_layers(net, impl::visit_layer_parameters(v)); } // ---------------------------------------------------------------------------------------- namespace impl { template < typename visitor > class visit_layer_parameter_gradients { public: explicit visit_layer_parameter_gradients(visitor& v) : v_(v) {} template void do_visit(layer& l) { // Call whatever version of the visitor the user provided. const bool visitor_called = call_if_valid(v_, computational_layer_idx, l.get_parameter_gradient()) || call_if_valid(v_, l.get_parameter_gradient()); DLIB_CASSERT(visitor_called, "A visitor function with an incorrect signature was given to visit_layer_parameter_gradients()"); ++computational_layer_idx; } // const version template void operator()(const add_layer& l) { do_visit(l); } // non-const version template void operator()(add_layer& l) { do_visit(l); } private: size_t computational_layer_idx = 0; visitor& v_; }; } template < typename net_type, typename visitor > void visit_layer_parameter_gradients( net_type& net, visitor v ) { visit_layers(net, impl::visit_layer_parameter_gradients(v)); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_CORE_H_ ================================================ FILE: dlib/dnn/core_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#undef DLIB_DNn_CORE_ABSTRACT_H_ #ifdef DLIB_DNn_CORE_ABSTRACT_H_ #include "../cuda/tensor_abstract.h" #include #include #include #include #include "../rand.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename... T > auto tuple_tail( const std::tuple& item ); /*! ensures - returns a tuple that contains everything in item except for tuple_head(item). The items will be in the same order as they are in item, just without tuple_head(item). - This function will correctly handle nested tuples. !*/ template auto tuple_head ( const std::tuple& item ); /*! ensures - returns a copy of the first thing in the tuple that isn't a std::tuple. Essentially, this function calls std::get<0>() recursively on item until a non-std::tuple object is found. !*/ // ---------------------------------------------------------------------------------------- template double get_learning_rate_multiplier( const T& obj ); /*! ensures - if (obj has a get_learning_rate_multiplier() member function) then - returns obj.get_learning_rate_multiplier() - else - returns 1 !*/ template void set_learning_rate_multiplier( T& obj, double learning_rate_multiplier ); /*! requires - learning_rate_multiplier >= 0 ensures - if (obj has a set_learning_rate_multiplier() member function) then - calls obj.set_learning_rate_multiplier(learning_rate_multiplier) - else - does nothing !*/ // ---------------------------------------------------------------------------------------- template double get_bias_learning_rate_multiplier( const T& obj ); /*! ensures - if (obj has a get_bias_learning_rate_multiplier() member function) then - returns obj.get_bias_learning_rate_multiplier() - else - returns 1 !*/ template void set_bias_learning_rate_multiplier( T& obj, double bias_learning_rate_multiplier ); /*! 
requires - bias_learning_rate_multiplier >= 0 ensures - if (obj has a set_bias_learning_rate_multiplier() member function) then - calls obj.set_bias_learning_rate_multiplier(bias_learning_rate_multiplier) - else - does nothing !*/ // ---------------------------------------------------------------------------------------- template double get_weight_decay_multiplier( const T& obj ); /*! ensures - if (obj has a get_weight_decay_multiplier() member function) then - returns obj.get_weight_decay_multiplier() - else - returns 1 !*/ template void set_weight_decay_multiplier( T& obj, double weight_decay_multiplier ); /*! requires - weight_decay_multiplier >= 0 ensures - if (obj has a set_weight_decay_multiplier() member function) then - calls obj.set_weight_decay_multiplier(weight_decay_multiplier) - else - does nothing !*/ // ---------------------------------------------------------------------------------------- template double get_bias_weight_decay_multiplier( const T& obj ); /*! ensures - if (obj has a get_bias_weight_decay_multiplier() member function) then - returns obj.get_bias_weight_decay_multiplier() - else - returns 1 !*/ template void set_bias_weight_decay_multiplier( T& obj, double bias_weight_decay_multiplier ); /*! requires: - bias_weight_decay_multiplier >= 0 ensures - if (obj has a set_bias_weight_decay_multiplier() member function) then - calls obj.set_bias_weight_decay_multiplier(bias_weight_decay_multiplier) - else - does nothing !*/ // ---------------------------------------------------------------------------------------- template void disable_bias( T& obj ); /*! ensures - if (obj has a disable_bias() member function) then - calls obj.disable_bias() - else - does nothing !*/ // ---------------------------------------------------------------------------------------- bool dnn_prefer_fastest_algorithms( ); /*! ensures - If dlib should prefer to use fast algorithms rather than ones that use less RAM then this function returns true and false otherwise. 
- On program startup this function will default to true. !*/ void set_dnn_prefer_fastest_algorithms( ); /*! ensures - #dnn_prefer_fastest_algorithms() == true !*/ void set_dnn_prefer_smallest_algorithms( ); /*! ensures - #dnn_prefer_fastest_algorithms() == false !*/ // ---------------------------------------------------------------------------------------- template < typename T > class sstack { /*! WHAT THIS OBJECT REPRESENTS This is a basic stack of T objects. It contains no data itself but simply points to a memory range of T object and allows you to access that block of T objects as a stack. !*/ public: typedef T value_type; sstack() = delete; sstack ( T* data, size_t s ); /*! ensures - #size() == s - #top() == *data - #pop(i).top() == data[i] !*/ const T& top( ) const; /*! requires - size() != 0 ensures - returns the top element of the stack. !*/ T& top( ); /*! requires - size() != 0 ensures - returns the top element of the stack. !*/ size_t size( ) const; /*! ensures - returns the number of elements in this stack. !*/ sstack pop( size_t num = 1 ); /*! requires - num <= size() ensures - returns a reference to the sub-stack S such that: - S.size() == size()-num. - S.top() is num elements down the stack. !*/ }; template < typename T > sstack make_sstack( std::vector& item ) { return sstack(item.data(), item.size()); } /*! ensures - returns a sstack that sits on top of the given std::vector. !*/ // ---------------------------------------------------------------------------------------- enum class zero_gradients : uint8_t { no, yes }; // ---------------------------------------------------------------------------------------- template < typename LAYER_DETAILS, typename SUBNET > class add_layer { /*! 
REQUIREMENTS ON LAYER_DETAILS - Must be a type that implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined in layers_abstract.h REQUIREMENTS ON SUBNET - One of the following must be true: - SUBNET implements the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h. - SUBNET is an add_layer object. - SUBNET is an add_tag_layer object. - SUBNET is an add_skip_layer object. - SUBNET is a repeat object. WHAT THIS OBJECT REPRESENTS This object represents a deep neural network. In particular, it is a tool for adding another layer on top of the neural network of type SUBNET, which is specified as a template argument. The specific layer added is defined by the LAYER_DETAILS details template argument. !*/ public: typedef LAYER_DETAILS layer_details_type; typedef SUBNET subnet_type; typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; // num_computational_layers will always give the number of layers in the network // that transform tensors (i.e. layers defined by something that implements the // EXAMPLE_COMPUTATIONAL_LAYER_ interface). This is all the layers except for // loss, tag, and skip layers. const static size_t num_computational_layers = subnet_type::num_computational_layers + 1; // num_layers counts all the layers in the network regardless of their type. const static size_t num_layers = subnet_type::num_layers + 1; add_layer( ); /*! ensures - default constructs all the layers in this network. - #sample_expansion_factor() == 0 !*/ add_layer(const add_layer&) = default; add_layer(add_layer&&) = default; add_layer& operator=(add_layer&&) = default; add_layer& operator=(const add_layer&) = default; /*! ensures - this object is copyable and movable. !*/ template add_layer( const add_layer& item ); /*! ensures - This constructor allows you to copy neural network objects from one to another as long as their corresponding layers can be constructed from each other. 
- #layer_details() == layer_details_type(item.layer_details()) - #subnet() == subnet_type(item.subnet()) - #sample_expansion_factor() == item.sample_expansion_factor() !*/ template add_layer( const std::tuple& layer_det, T&& ...args ); /*! ensures - #layer_details() == layer_details_type(tuple_head(layer_det)) - #subnet() == subnet_type(tuple_tail(layer_det),args) - #sample_expansion_factor() == 0 !*/ template add_layer( const layer_details_type& layer_det, T&& ...args ); /*! ensures - #layer_details() == layer_details_type(layer_det) - #subnet() == subnet_type(args) - #sample_expansion_factor() == 0 !*/ template add_layer( T&& ...args ); /*! ensures - This version of the constructor is only called if layer_details_type can't be constructed from the first thing in args. In this case, the args are simply passed on to the sub layers in their entirety. - #layer_details() == layer_details_type() - #subnet() == subnet_type(args) - #sample_expansion_factor() == 0 !*/ template add_layer( layer_details_type&& layer_det, T&& ...args ); /*! ensures - #layer_details() == layer_det - #subnet() == subnet_type(args) - #sample_expansion_factor() == 0 !*/ template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 ensures - Converts the iterator range into a tensor and stores it into #data. - #data.num_samples()%distance(ibegin,iend) == 0. - #sample_expansion_factor() == #data.num_samples()/distance(ibegin,iend). - #sample_expansion_factor() > 0 - The data in the ith sample of #data corresponds to the input_type object *(ibegin+i/#sample_expansion_factor()). - Invokes data.async_copy_to_device() so that the data begins transferring to the GPU device, if present. - This function is implemented by calling the to_tensor() routine defined at the input layer of this network. !*/ unsigned int sample_expansion_factor ( ) const; /*! 
ensures - When to_tensor() is invoked on this network's input layer it converts N input objects into M samples, all stored inside a resizable_tensor. It is always the case that M is some integer multiple of N. sample_expansion_factor() returns the value of this multiplier. To be very specific, it is always true that M==I*N where I is some integer. This integer I is what is returned by sample_expansion_factor(). !*/ const subnet_type& subnet( ) const; /*! ensures - returns the immediate subnetwork of *this network. !*/ subnet_type& subnet( ); /*! ensures - returns the immediate subnetwork of *this network. !*/ const input_layer_type& input_layer( ) const; /*! ensures - returns the very first layer in *this network. It's equivalent to calling subnet() recursively until you get to the first layer. This means it will return the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h !*/ input_layer_type& input_layer( ); /*! ensures - returns the very first layer in *this network. It's equivalent to calling subnet() recursively until you get to the first layer. This means it will return the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h !*/ const layer_details_type& layer_details( ) const; /*! ensures - returns the layer_details_type instance that defines the behavior of the layer at the top of this network. I.e. returns the layer details that defines the behavior of the layer nearest to the network output rather than the input layer. !*/ layer_details_type& layer_details( ); /*! ensures - returns the layer_details_type instance that defines the behavior of the layer at the top of this network. I.e. returns the layer details that defines the behavior of the layer nearest to the network output rather than the input layer. !*/ template const tensor& operator() ( forward_iterator ibegin, forward_iterator iend ); /*! 
requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 ensures - runs [ibegin,iend) through the network and returns the results. In particular, this function performs: to_tensor(ibegin,iend,temp_tensor); return forward(temp_tensor); - The return value from this function is also available in #get_output(). i.e. this function returns #get_output(). - have_same_dimensions(#get_gradient_input(), #get_output()) == true. - All elements of #get_gradient_input() are set to 0. i.e. calling this function clears out #get_gradient_input() and ensures it has the same dimensions as the most recent output. !*/ const tensor& operator() ( const input_type& x ); /*! ensures - runs a single x through the network and returns the output. I.e. returns (*this)(&x, &x+1); !*/ const tensor& forward( const tensor& x ); /*! requires - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 ensures - Runs x through the network and returns the results. In particular, this function performs the equivalent of: subnet().forward(x); if (this is the first time forward() has been called) then layer_details().setup(subnet()); layer_details().forward(subnet(), get_output()); - The return value from this function is also available in #get_output(). i.e. this function returns #get_output(). - have_same_dimensions(#get_gradient_input(), #get_output()) == true - All elements of #get_gradient_input() are set to 0. i.e. calling this function clears out #get_gradient_input() and ensures it has the same dimensions as the most recent output. !*/ const tensor& get_output( ) const; /*! ensures - returns the output for the last tensor that was run through the network. If nothing has been run through the network yet then returns an empty tensor. !*/ tensor& get_gradient_input( ); /*! 
ensures - returns the error gradient for this network. That is, this is the error gradient that this network will use to compute parameter gradients when back_propagate_error() is called. Therefore, when performing back propagation, layers that sit on top of this network layer write their back-propagated error gradients into get_gradient_input(). Or to put it another way, during back-propagation, layers take the contents of their get_gradient_input() and back-propagate it through themselves and store the result into their subnetwork's get_gradient_input(). This means you should consider get_gradient_input() as an input to the back_propagate_error() method. !*/ const tensor& get_final_data_gradient( ) const; /*! ensures - if back_propagate_error() has been called to back-propagate a gradient through this network then you can call get_final_data_gradient() to obtain the last data gradient computed. That is, this function returns the gradient of the network with respect to its inputs. - Note that there is only one "final data gradient" for an entire network, not one per layer, since there is only one input to the entire network. !*/ const tensor& get_parameter_gradient( ) const; /*! ensures - if back_propagate_error() has been called then you can call get_parameter_gradient() to find the gradient of this layer's parameters. When we update the parameters by calling update_parameters(), it will use the gradient in get_parameter_gradient() to perform the update. Therefore, you should consider get_parameter_gradient() as an input to update_parameters(). !*/ tensor& get_parameter_gradient ( ); /*! ensures - returns a non-const reference to the tensor returned by the above get_parameter_gradient() method. You could use this method to modify the parameter gradient in some way before invoking update_parameters(). !*/ void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ); /*! 
requires - forward(x) was called to forward propagate x though the network. Moreover, this was the most recent call to forward() and x has not been subsequently modified in any way. - get_gradient_input() has been set equal to the gradient of this network's output with respect to some loss function. ensures - Back propagates the error gradient, get_gradient_input(), through this network and computes parameter and data gradients, via backpropagation. Specifically, this function populates get_final_data_gradient() and also, for each layer, the tensor returned by get_parameter_gradient(). - All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - have_same_dimensions(#get_final_data_gradient(), x) == true. - have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true. - #get_final_data_gradient() contains the gradient of the network with respect to x. !*/ void back_propagate_error( const tensor& x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - forward(x) was called to forward propagate x though the network. Moreover, this was the most recent call to forward() and x has not been subsequently modified in any way. - have_same_dimensions(gradient_input, get_output()) == true ensures - This function is identical to the version of back_propagate_error() defined immediately above except that it back-propagates gradient_input through the network instead of get_gradient_input(). Therefore, this version of back_propagate_error() is equivalent to performing: get_gradient_input() = gradient_input; back_propagate_error(x); Except that calling back_propagate_error(x,gradient_input) avoids the copy and is therefore slightly more efficient. - All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - have_same_dimensions(#get_final_data_gradient(), x) == true. 
- have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true. - #get_final_data_gradient() contains the gradient of the network with respect to x. !*/ template void update_parameters( sstack solvers, double learning_rate ); /*! requires - solver_type is an implementation of the EXAMPLE_SOLVER interface defined in solvers_abstract.h - back_propagate_error() has been called. - The given solvers have only ever been used with this network. That is, if you want to call update_parameters() on some other neural network object then you must NOT reuse the same solvers object. - solvers.size() >= num_computational_layers - 0 < learning_rate <= 1 ensures - Updates all the parameters in the network. In particular, we pass each layer's parameter gradient (i.e. the tensor returned by the layer's get_parameter_gradient() member) through that layer's corresponding solver object. This produces a parameter delta vector which we add to the layer's parameters. - The solvers use the given learning rate. !*/ template void update_parameters(std::vector& solvers, double learning_rate) { update_parameters(make_sstack(solvers), learning_rate); } /*! Convenience method for calling update_parameters() !*/ void set_gradient_inputs_to_zero( ); /*! ensures - Sets all elements in all gradient inputs in the network to 0. That is, for each layer, we will have: - get_gradient_input() == 0 - Note that You only need to call this method if you manually called either - back_propagate_error - compute_parameter_gradients with the zero_grads parameter set to zero_gradients::no. - invokes subnet().set_gradient_inputs_to_zero() !*/ void clean( ); /*! ensures - Causes the network to forget about everything but its parameters. 
That is, for each layer we will have: - get_output().num_samples() == 0 - get_gradient_input().num_samples() == 0 However, running new input data though this network will still produce the same output it would have produced regardless of any calls to clean(). The purpose of clean() is to compact the network object prior to saving it to disk so that it takes up less space and the IO is quicker. - This also calls the .clean() method on any layer details objects that define a .clean() method. !*/ }; template std::ostream& operator<<(std::ostream& out, const add_layer& item); /*! prints the network architecture to the given output stream. !*/ template void serialize(const add_layer& item, std::ostream& out); template void deserialize(add_layer& item, std::istream& in); /*! provides serialization support !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- class no_label_type; template < typename LOSS_DETAILS, typename SUBNET > class add_loss_layer { /*! REQUIREMENTS ON LOSS_DETAILS - Must be a type that implements the EXAMPLE_LOSS_LAYER_ interface defined in loss_abstract.h REQUIREMENTS ON SUBNET - One of the following must be true: - SUBNET is an add_layer object. - SUBNET is an add_tag_layer object. - SUBNET is an add_skip_layer object. - SUBNET is a repeat object. WHAT THIS OBJECT REPRESENTS This object represents a deep neural network. In particular, it is a tool for adding a loss layer on top of the neural network of type SUBNET, which is specified as a template argument. The specific layer added is defined by the LOSS_DETAILS details template argument. Importantly, a loss layer is the last layer in a deep neural network. So once it is added you can't add any other layers of any type. 
!*/ public: typedef LOSS_DETAILS loss_details_type; typedef SUBNET subnet_type; typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; const static size_t num_computational_layers = subnet_type::num_computational_layers; const static size_t num_layers = subnet_type::num_layers + 1; // If LOSS_DETAILS is an unsupervised loss then training_label_type==no_label_type. // Otherwise it is defined as follows: typedef typename LOSS_DETAILS::training_label_type training_label_type; // Similarly, if LOSS_DETAILS doesn't provide any output conversion then // output_label_type==no_label_type. typedef typename LOSS_DETAILS::output_label_type output_label_type; add_loss_layer() = default; /*! ensures - default constructs all the layers in this network. !*/ add_loss_layer(const add_loss_layer&) = default; add_loss_layer(add_loss_layer&&) = default; add_loss_layer& operator=(add_loss_layer&&) = default; add_loss_layer& operator=(const add_loss_layer&) = default; /*! ensures - this object is copyable and movable. !*/ template add_loss_layer( const add_loss_layer& item ); /*! ensures - This constructor allows you to copy neural network objects from one to another as long as their corresponding layers can be constructed from each other. - #loss_details() == loss_details_type(item.loss_details()) - #subnet() == subnet_type(item.subnet()) !*/ template add_loss_layer( const LOSS_DETAILS& layer_det, T&& ...args ); /*! ensures - #loss_details() == loss_details_type(layer_det) - #subnet() == subnet_type(args) !*/ template add_loss_layer( LOSS_DETAILS&& layer_det, T&& ...args ); /*! ensures - #loss_details() == loss_details_type(layer_det) - #subnet() == subnet_type(args) !*/ template add_loss_layer( T&& ...args ); /*! ensures - This version of the constructor is only called if loss_details_type can't be constructed from the first thing in args. In this case, the args are simply passed on to the sub layers in their entirety. 
- #loss_details() == loss_details_type() - #subnet() == subnet_type(args) !*/ const subnet_type& subnet( ) const; /*! ensures - returns the immediate subnetwork of *this network. !*/ subnet_type& subnet( ); /*! ensures - returns the immediate subnetwork of *this network. !*/ const input_layer_type& input_layer( ) const; /*! ensures - returns the very first layer in *this network. It's equivalent to calling subnet() recursively until you get to the first layer. This means it will return the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h !*/ input_layer_type& input_layer( ); /*! ensures - returns the very first layer in *this network. It's equivalent to calling subnet() recursively until you get to the first layer. This means it will return the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h !*/ const loss_details_type& loss_details( ) const; /*! ensures - returns the loss_details_type instance that defines the behavior of the loss layer used by this network. !*/ loss_details_type& loss_details( ); /*! ensures - returns the loss_details_type instance that defines the behavior of the loss layer used by this network. !*/ template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 ensures - Converts the iterator range into a tensor and stores it into #data. - #data.num_samples()%distance(ibegin,iend) == 0. - #sample_expansion_factor() == #data.num_samples()/distance(ibegin,iend). - #sample_expansion_factor() > 0 - The data in the ith sample of #data corresponds to the input_type object *(ibegin+i/sample_expansion_factor()). - Invokes data.async_copy_to_device() so that the data begins transferring to the GPU device, if present. 
- This function is implemented by calling the to_tensor() routine defined at the input layer of this network. !*/ unsigned int sample_expansion_factor ( ) const; /*! ensures - When to_tensor() is invoked on this network's input layer it converts N input objects into M samples, all stored inside a resizable_tensor. It is always the case that M is some integer multiple of N. sample_expansion_factor() returns the value of this multiplier. To be very specific, it is always true that M==I*N where I is some integer. This integer I is what is returned by sample_expansion_factor(). !*/ // ------------- const tensor& forward(const tensor& x ); /*! requires - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 ensures - Runs x through the network and returns the results as a tensor. In particular, this function just performs: return subnet().forward(x); So if you want to get the outputs as an output_label_type then call one of the methods below instead, like operator(). - The return value from this function is also available in #subnet().get_output(). i.e. this function returns #subnet().get_output(). - have_same_dimensions(#subnet().get_gradient_input(), #subnet().get_output()) == true - All elements of #subnet().get_gradient_input() are set to 0. i.e. calling this function clears out #subnet().get_gradient_input() and ensures it has the same dimensions as the most recent output. !*/ template void operator() ( const tensor& x, output_iterator obegin ); /*! requires - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 - obegin == iterator pointing to the start of a range of x.num_samples()/sample_expansion_factor() output_label_type elements. 
ensures - runs x through the network and writes the output to the range at obegin. - loss_details().to_label() is used to write the network output into obegin. !*/ template void operator() ( forward_iterator ibegin, forward_iterator iend, label_iterator obegin ); /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - obegin == iterator pointing to the start of a range of std::distance(ibegin,iend) output_label_type elements. ensures - runs [ibegin,iend) through the network and writes the output to the range at obegin. - loss_details().to_label() is used to write the network output into obegin. !*/ // ------------- const output_label_type& operator() ( const input_type& x ); /*! ensures - runs a single object, x, through the network and returns the output. - loss_details().to_label() is used to convert the network output into a output_label_type. !*/ template std::vector operator() ( const iterable_type& data, size_t batch_size = 128 ); /*! requires - batch_size > 0 - data must have a .begin() and .end() that supply iterators over a sequence of input_type elements. E.g. data could have a type of std::vector ensures - runs all the objects in data through the network and returns their predicted labels. This means this function returns a vector V such that: - V.size() == data.size() - for all valid i: V[i] == the predicted label of data[i]. - Elements of data are run through the network in batches of batch_size items. Using a batch_size > 1 can be faster because it better exploits the available hardware parallelism. - loss_details().to_label() is used to convert the network output into a output_label_type. !*/ template const output_label_type& process ( const input_type& x, T&& ...args ); /*! ensures - This function is just like (*this)(x), i.e. it runs a single object, x, through the network and returns the output. 
But we additionally pass the given args to loss_details().to_label() as the 4th argument (or more, depending on how many things are in args) when converting the network output to an output_label_type. This is useful, for instance, with loss layers like loss_mmod_ which has an optional adjust_threshold argument to to_label() that adjusts the detection threshold. Therefore, for such networks you could call them like: net.process(some_image, -0.5), and -0.5 would be passed so the adjust_threshold argument of to_tensor(). !*/ template std::vector process_batch ( const iterable_type& data, size_t batch_size, T&& ...args ); /*! requires - batch_size > 0 - data must have a .begin() and .end() that supply iterators over a sequence of input_type elements. E.g. data could have a type of std::vector ensures - This function is just like (*this)(data,batch_size), i.e. it runs a bunch of objects through the network and returns the outputs. But we additionally pass the given args to loss_details().to_label() as the 4th argument (or more, depending on how many things are in args) when converting the network output to output_label_types. This is useful, for instance, with loss layers like loss_mmod_ which has an optional adjust_threshold argument to to_label() that adjusts the detection threshold. Therefore, for such networks you could call them like: net.process_batch(std::vector({some_image, another_image}), 128, -0.5), and -0.5 would be passed so the adjust_threshold argument of to_tensor(). !*/ // ------------- template double compute_loss ( const tensor& x, label_iterator lbegin ); /*! requires - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 - lbegin == iterator pointing to the start of a range of x.num_samples()/sample_expansion_factor() training_label_type elements. 
ensures - runs x through the network, compares the output to the expected output pointed to by lbegin, and returns the resulting loss. - for all valid k: - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()). - This function does not update the network parameters. - For sub-layers that are immediate inputs into the loss layer, we also populate the sub-layer's get_gradient_input() tensor with the gradient of the loss with respect to the sub-layer's output. !*/ template double compute_loss ( forward_iterator ibegin, forward_iterator iend, label_iterator lbegin ); /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - lbegin == iterator pointing to the start of a range of std::distance(ibegin,iend) training_label_type elements. ensures - runs [ibegin,iend) through the network, compares the output to the expected output pointed to by lbegin, and returns the resulting loss. - for all valid k: - the expected label of *(ibegin+k) is *(lbegin+k). - This function does not update the network parameters. - For sub-layers that are immediate inputs into the loss layer, we also populate the sub-layer's get_gradient_input() tensor with the gradient of the loss with respect to the sub-layer's output. !*/ // ------------- double compute_loss ( const tensor& x ); /*! requires - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 ensures - runs x through the network and returns the resulting loss. - This function does not update the network parameters. - For sub-layers that are immediate inputs into the loss layer, we also populate the sub-layer's get_gradient_input() tensor with the gradient of the loss with respect to the sub-layer's output. 
!*/ template double compute_loss ( forward_iterator ibegin, forward_iterator iend ); /*! requires - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 ensures - runs [ibegin,iend) through the network and returns the resulting loss. - This function does not update the network parameters. - For sub-layers that are immediate inputs into the loss layer, we also populate the sub-layer's get_gradient_input() tensor with the gradient of the loss with respect to the sub-layer's output. !*/ // ------------- template double compute_parameter_gradients ( const tensor& x, label_iterator lbegin, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 - lbegin == iterator pointing to the start of a range of x.num_samples()/sample_expansion_factor() training_label_type elements. ensures - runs x through the network, compares the output to the expected output pointed to by lbegin, and computes parameter and data gradients with respect to the loss, via backpropagation. Specifically, this function updates get_final_data_gradient() and also, for each layer, the tensor returned by get_parameter_gradient(). - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - for all valid k: - the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()). - returns compute_loss(x,lbegin) !*/ template double compute_parameter_gradients ( forward_iterator ibegin, forward_iterator iend, label_iterator lbegin, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - [ibegin, iend) is an iterator range over input_type objects.
- std::distance(ibegin,iend) > 0 - lbegin == iterator pointing to the start of a range of std::distance(ibegin,iend) training_label_type elements. ensures - runs [ibegin,iend) through the network, compares the output to the expected output pointed to by lbegin, and computes parameter and data gradients with respect to the loss, via backpropagation. Specifically, this function updates get_final_data_gradient() and also, for each layer, the tensor returned by get_parameter_gradient(). - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - for all valid k: - the expected label of *(ibegin+k) is *(lbegin+k). - returns compute_loss(ibegin,iend,lbegin) !*/ double compute_parameter_gradients ( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. - sample_expansion_factor() != 0 (i.e. to_tensor() must have been called to set sample_expansion_factor() to something non-zero.) - x.num_samples()%sample_expansion_factor() == 0 - x.num_samples() > 0 ensures - runs x through the network and computes parameter and data gradients with respect to the loss, via backpropagation. Specifically, this function updates get_final_data_gradient() and also, for each layer, the tensor returned by get_parameter_gradient(). - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - returns compute_loss(x) !*/ template double compute_parameter_gradients ( forward_iterator ibegin, forward_iterator iend, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - LOSS_DETAILS is an unsupervised loss. i.e. training_label_type==no_label_type. - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 ensures - runs [ibegin,iend) through the network and computes parameter and data gradients with respect to the loss, via backpropagation. 
Specifically, this function updates get_final_data_gradient() and also, for each layer, the tensor returned by get_parameter_gradient(). - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - returns compute_loss(ibegin,iend) !*/ template void update_parameters ( sstack solvers, double learning_rate ); /*! requires - solver_type is an implementation of the EXAMPLE_SOLVER interface defined in solvers_abstract.h - compute_parameter_gradients() has been called. - The given solvers have only ever been used with this network. That is, if you want to call update_parameters() on some other neural network object then you must NOT reuse the same solvers object. - solvers.size() >= num_computational_layers - 0 < learning_rate <= 1 ensures - Updates all the parameters in the network. In particular, we pass each layer's parameter gradient (i.e. the tensor returned by the layer's get_parameter_gradient() member) through that layer's corresponding solver object. This produces a parameter delta vector which we add to the layer's parameters. - The solvers use the given learning rate. !*/ template void update_parameters(std::vector& solvers, double learning_rate ) { update_parameters(make_sstack(solvers), learning_rate); } /*! Convenience method for calling update_parameters() !*/ void back_propagate_error( const tensor& x, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - forward(x) was called to forward propagate x through the network. Moreover, this was the most recent call to forward() and x has not been subsequently modified in any way. - subnet().get_gradient_input() has been set equal to the gradient of this network's output with respect to the loss function (generally this will be done by calling compute_loss()). ensures - Back propagates the error gradient, subnet().get_gradient_input(), through this network and computes parameter and data gradients, via backpropagation.
Specifically, this function populates get_final_data_gradient() and also, for each layer, the tensor returned by get_parameter_gradient(). - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - have_same_dimensions(#get_final_data_gradient(), x) == true. - #get_final_data_gradient() contains the gradient of the network with respect to x. !*/ void back_propagate_error( const tensor& x, const tensor& gradient_input, zero_gradients zero_grads = zero_gradients::yes ); /*! requires - forward(x) was called to forward propagate x through the network. Moreover, this was the most recent call to forward() and x has not been subsequently modified in any way. - have_same_dimensions(gradient_input, subnet().get_output()) == true ensures - This function is identical to the version of back_propagate_error() defined immediately above except that it back-propagates gradient_input through the network instead of subnet().get_gradient_input(). Therefore, this version of back_propagate_error() is equivalent to performing: subnet().get_gradient_input() = gradient_input; back_propagate_error(x); Except that calling back_propagate_error(x,gradient_input) avoids the copy and is therefore slightly more efficient. - All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes. - have_same_dimensions(#get_final_data_gradient(), x) == true. - #get_final_data_gradient() contains the gradient of the network with respect to x. !*/ const tensor& get_final_data_gradient( ) const; /*! ensures - if back_propagate_error() has been called to back-propagate a gradient through this network then you can call get_final_data_gradient() to obtain the last data gradient computed. That is, this function returns the gradient of the network with respect to its inputs. - Note that there is only one "final data gradient" for an entire network, not one per layer, since there is only one input to the entire network. !*/ void set_gradient_inputs_to_zero( ); /*!
ensures - Sets all elements in all gradient inputs in the network to 0. - invokes subnet().set_gradient_inputs_to_zero() !*/ // ------------- void clean ( ); /*! ensures - Causes the network to forget about everything but its parameters. - invokes subnet().clean() !*/ }; template std::ostream& operator<<(std::ostream& out, const add_loss_layer& item); /*! prints the network architecture to the given output stream. !*/ template void serialize(const add_loss_layer& item, std::ostream& out); template void deserialize(add_loss_layer& item, std::istream& in); /*! provides serialization support !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template decorator_repeat_group repeat_group ( T&& ...args ); /*! ensures - Decorates a group of variables. This is essentially like std::make_tuple() except its only purpose is to group variables together so they can be passed to the repeat object's constructor. !*/ template < size_t num, template class REPEATED_LAYER, typename SUBNET > class repeat { /*! REQUIREMENTS ON num - num > 0 REQUIREMENTS ON REPEATED_LAYER - REPEATED_LAYER must be a template that stacks more layers onto a deep neural network. For example, if net_type were a network without a loss layer, then it should be legal to create a deeper network with a type of REPEATED_LAYER<net_type>. REQUIREMENTS ON SUBNET - One of the following must be true: - SUBNET is an add_layer object. - SUBNET is an add_tag_layer object. - SUBNET is an add_skip_layer object. - SUBNET is a repeat object. WHAT THIS OBJECT REPRESENTS This object adds more layers to a deep neural network. In particular, it adds REPEATED_LAYER on top of SUBNET num times. So for example, if num were 2 then repeat<2,REPEATED_LAYER,SUBNET> would create a network equivalent to REPEATED_LAYER<REPEATED_LAYER<SUBNET>>.
Also, this object provides an interface identical to the one defined by the add_layer object except that we add the num_repetitions() and get_repeated_layer() methods. These additions are shown below along with some additional explanatory comments. !*/ public: typedef SUBNET subnet_type; typedef typename SUBNET::input_type input_type; typedef typename subnet_type::input_layer_type input_layer_type; const static size_t num_computational_layers = (REPEATED_LAYER::num_computational_layers-SUBNET::num_computational_layers)*num + SUBNET::num_computational_layers; const static size_t num_layers = (REPEATED_LAYER::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers; typedef REPEATED_LAYER repeated_layer_type; template repeat( T arg1, U ...args2 ); /*! ensures - arg1 is used to initialize the num_repetitions() copies of REPEATED_LAYER inside this object. That is, all the REPEATED_LAYER elements are initialized identically by being given copies of arg1. - The rest of the arguments to the constructor, i.e. args2, are passed to SUBNET's constructor. !*/ template repeat( decorator_repeat_group&& arg1, U ...args2 ); /*! ensures - arg1 is used to initialize the num_repetitions() copies of REPEATED_LAYER inside this object. That is, all the REPEATED_LAYER elements are initialized identically by being given copies of an undecorated arg1. - The rest of the arguments to the constructor, i.e. args2, are passed to SUBNET's constructor. !*/ size_t num_repetitions ( ) const; /*! ensures - returns num (i.e. the number of times REPEATED_LAYER was stacked on top of SUBNET) !*/ const repeated_layer_type& get_repeated_layer ( size_t i ) const; /*! requires - i < num_repetitions() ensures - returns a reference to the i-th instance of REPEATED_LAYER. For example, get_repeated_layer(0) returns the instance of REPEATED_LAYER that is on the top of the network while get_repeated_layer(num_repetitions()-1) returns the instance of REPEATED_LAYER that is stacked immediately on top of SUBNET. 
!*/ repeated_layer_type& get_repeated_layer ( size_t i ); /*! requires - i < num_repetitions() ensures - returns a reference to the i-th instance of REPEATED_LAYER. For example, get_repeated_layer(0) returns the instance of REPEATED_LAYER that is on the top of the network while get_repeated_layer(num_repetitions()-1) returns the instance of REPEATED_LAYER that is stacked immediately on top of SUBNET. !*/ const subnet_type& subnet( ) const; /*! ensures - returns the SUBNET base network that repeat sits on top of. If you want to access the REPEATED_LAYER components then you must use get_repeated_layer(). !*/ subnet_type& subnet( ); /*! ensures - returns the SUBNET base network that repeat sits on top of. If you want to access the REPEATED_LAYER components then you must use get_repeated_layer(). !*/ const input_layer_type& input_layer( ) const; /*! ensures - returns the very first layer in *this network. It's equivalent to calling subnet() recursively until you get to the first layer. This means it will return the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h !*/ input_layer_type& input_layer( ); /*! ensures - returns the very first layer in *this network. It's equivalent to calling subnet() recursively until you get to the first layer. This means it will return the object that is an implementation of the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h !*/ }; template < size_t num, template class T, typename U > std::ostream& operator<<(std::ostream& out, const repeat& item); /*! prints the network architecture to the given output stream. !*/ template < size_t num, template class T, typename U > void serialize(const repeat& item, std::ostream& out); template < size_t num, template class T, typename U > void deserialize(repeat& item, std::istream& in); /*! 
provides serialization support !*/ // ---------------------------------------------------------------------------------------- template < unsigned long ID, typename SUBNET > class add_tag_layer { /*! REQUIREMENTS ON SUBNET - One of the following must be true: - SUBNET implements the EXAMPLE_INPUT_LAYER interface defined in input_abstract.h. - SUBNET is an add_layer object. - SUBNET is an add_tag_layer object. - SUBNET is an add_skip_layer object. - SUBNET is a repeat object. WHAT THIS OBJECT REPRESENTS This object adds a new layer to a deep neural network. However, this layer simply performs the identity transform. This means it is a no-op and its presence does not change the behavior of the network. It exists solely to be used by add_skip_layer to reference a particular part of a network. Also, this object provides an interface identical to the one defined by the add_layer object. !*/ }; template std::ostream& operator<<(std::ostream& out, const add_tag_layer& item); /*! prints the network architecture to the given output stream. !*/ template void serialize(const add_tag_layer& item, std::ostream& out); template void deserialize(add_tag_layer& item, std::istream& in); /*! provides serialization support !*/ template using tag1 = add_tag_layer< 1, SUBNET>; template using tag2 = add_tag_layer< 2, SUBNET>; template using tag3 = add_tag_layer< 3, SUBNET>; template using tag4 = add_tag_layer< 4, SUBNET>; template using tag5 = add_tag_layer< 5, SUBNET>; template using tag6 = add_tag_layer< 6, SUBNET>; template using tag7 = add_tag_layer< 7, SUBNET>; template using tag8 = add_tag_layer< 8, SUBNET>; template using tag9 = add_tag_layer< 9, SUBNET>; template using tag10 = add_tag_layer<10, SUBNET>; template class tag> struct tag_id { /*! REQUIREMENTS ON tag Tag should be an add_tag_layer template such as tag1, tag2, etc. WHAT THIS OBJECT REPRESENTS This is a tool for finding the numeric ID of a tag layer. For example, tag_id::id == 3. 
!*/ const static unsigned long id; }; // ---------------------------------------------------------------------------------------- template < template class TAG_TYPE, typename SUBNET > class add_skip_layer { /*! REQUIREMENTS ON SUBNET - One of the following must be true: - SUBNET is an add_layer object. - SUBNET is an add_tag_layer object. - SUBNET is an add_skip_layer object. - SUBNET is a repeat object. WHAT THIS OBJECT REPRESENTS This object adds a new layer to a deep neural network which draws its inputs from layer(subnet()) and performs the identity transform. Also, this object provides an interface identical to the one defined by the add_layer object. !*/ }; template class T, typename U> std::ostream& operator<<(std::ostream& out, const add_skip_layer& item); /*! prints the network architecture to the given output stream. !*/ template class T, typename U> void serialize(const add_skip_layer& item, std::ostream& out); template class T, typename U> void deserialize(add_skip_layer& item, std::istream& in); /*! provides serialization support !*/ template using skip1 = add_skip_layer< tag1, SUBNET>; template using skip2 = add_skip_layer< tag2, SUBNET>; template using skip3 = add_skip_layer< tag3, SUBNET>; template using skip4 = add_skip_layer< tag4, SUBNET>; template using skip5 = add_skip_layer< tag5, SUBNET>; template using skip6 = add_skip_layer< tag6, SUBNET>; template using skip7 = add_skip_layer< tag7, SUBNET>; template using skip8 = add_skip_layer< tag8, SUBNET>; template using skip9 = add_skip_layer< tag9, SUBNET>; template using skip10 = add_skip_layer; // ---------------------------------------------------------------------------------------- template < unsigned int i, typename net_type > auto& layer ( net_type& n ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - i < net_type::num_layers ensures - This function allows you to access any layer in a network by its layer index i. 
Therefore, it will walk i steps down the network and return the layer object there. Since networks can be big, the best way to find layer index numbers is to print a network to the screen since the print out will include indexes for each layer. - In general, this function chains together i calls to n.subnet() and returns the result. So for example: - if (i == 0) - returns n - else if (i == 1) - returns n.subnet() - else if (i == 2) - returns n.subnet().subnet() - else if (i == 3) - returns n.subnet().subnet().subnet() - else - etc. Except that when it hits a repeat layer it recurses into the repeated layers contained inside. That is, if the layer index indicates a layer in a repeat object this function will make the appropriate call to get_repeated_layer() and do the right thing. !*/ template < template class Match, typename net_type > auto& layer ( net_type& n ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. ensures - returns the first layer in n that is of type Match. E.g. if net_type is fc>>> then calling layer(n) would return layer<1>(n), that is, a reference to the relu layer. !*/ template < template class Match, unsigned int i, typename net_type > auto& layer ( net_type& n ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. ensures - returns layer(layer(n)) !*/ // ---------------------------------------------------------------------------------------- template auto& input_layer ( net_type& net ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, repeat, or add_tag_layer. ensures - returns the input layer of the given network object. This is the same as just calling net.input_layer(). !*/ // ---------------------------------------------------------------------------------------- template < typename net_type, typename visitor > void visit_layer_parameters( net_type& net, visitor v ); /*! 
requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, tensor& t) or: v(tensor& t) ensures - Loops over all the computational layers (i.e. layers with parameters, as opposed to loss, tag, or input layers) in net and passes their parameters to v(). To be specific, this function essentially performs the following: size_t computational_layer_idx = 0; for (size_t i = 0; i < net_type::num_layers; ++i) { if (layer(net) is a computational layer) { v(computational_layer_idx, layer(net).layer_details().get_layer_params()); ++computational_layer_idx; } } - When v() is called, the first argument is always < net_type::num_computational_layers. !*/ // ---------------------------------------------------------------------------------------- template < typename net_type, typename visitor > void visit_layer_parameter_gradients( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, tensor& t) or: v(tensor& t) ensures - Loops over all the computational layers (i.e. layers with parameters, as opposed to loss, tag, or input layers) in net and passes their parameter gradients to v(). To be specific, this function essentially performs the following: size_t computational_layer_idx = 0; for (size_t i = 0; i < net_type::num_layers; ++i) { if (layer(net) is a computational layer) { v(computational_layer_idx, layer(net).get_parameter_gradient()); ++computational_layer_idx; } } - When v() is called, the first argument is always < net_type::num_computational_layers. !*/ // ---------------------------------------------------------------------------------------- template < typename net_type, typename visitor > void visit_layers( net_type& net, visitor v ); /*! 
requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, any_net_type& t) or: v(any_net_type& t) That is, it takes an optional size_t and then any of the network types such as add_layer, add_loss_layer, etc. ensures - Loops over all the layers in net and calls v() on them. To be specific, this function essentially performs the following: for (size_t i = 0; i < net_type::num_layers; ++i) v(i, layer(net)); !*/ template < typename net_type, typename visitor > void visit_layers_backwards( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, any_net_type& t) or: v(any_net_type& t) That is, it takes an optional size_t and then any of the network types such as add_layer, add_loss_layer, etc. ensures - Loops over all the layers in net and calls v() on them. The loop happens in the reverse order of visit_layers(). To be specific, this function essentially performs the following: for (size_t i = net_type::num_layers; i != 0; --i) v(i-1, layer(net)); !*/ // ---------------------------------------------------------------------------------------- template < typename net_type, typename visitor > void visit_computational_layers( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, any_computational_layer& t) or: v(any_computational_layer& t) That is, it takes an optional size_t and then any of the computational layers. E.g. one of the layer types defined in dlib/dnn/layers_abstract.h like fc_ or conv_. ensures - Loops over all the computational layers in net and calls v() on them. 
To be specific, this function essentially performs the following: for (size_t i = 0; i < net_type::num_layers; ++i) if (layer(net) is an add_layer type, i.e. it adds a computational layer) v(i, layer(net).layer_details()); !*/ template < size_t begin, size_t end, typename net_type, typename visitor > void visit_computational_layers_range( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, any_computational_layer& t) or: v(any_computational_layer& t) That is, it takes an optional size_t and then any of the computational layers. E.g. one of the layer types defined in dlib/dnn/layers_abstract.h like fc_ or conv_. ensures - Loops over all the computational layers in the range [begin,end) in net and calls v() on them. To be specific, this function essentially performs the following: for (size_t i = begin; i < end; ++i) if (layer(net) is an add_layer type, i.e. it adds a computational layer) v(i, layer(net).layer_details()); !*/ // ---------------------------------------------------------------------------------------- template < size_t begin, size_t end, typename net_type, typename visitor > void visit_layers_range( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, any_net_type& t) or: v(any_net_type& t) That is, it takes an optional size_t and then any of the network types such as add_layer, add_loss_layer, etc. - begin <= end <= net_type::num_layers ensures - Loops over the layers in the range [begin,end) in net and calls v() on them. 
To be specific, this function essentially performs the following: for (size_t i = begin; i < end; ++i) v(i, layer(net)); !*/ template < size_t begin, size_t end, typename net_type, typename visitor > void visit_layers_backwards_range( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(size_t idx, any_net_type& t) or: v(any_net_type& t) That is, it takes an optional size_t and then any of the network types such as add_layer, add_loss_layer, etc. - begin <= end <= net_type::num_layers ensures - Loops over the layers in the range [begin,end) in net and calls v() on them. The loop happens in the reverse order of visit_layers_range(). To be specific, this function essentially performs the following: for (size_t i = end; i != begin; --i) v(i-1, layer(net)); !*/ // ---------------------------------------------------------------------------------------- template < unsigned long tag_id, typename net_type, typename visitor > void visit_layers_until_tag( net_type& net, visitor v ); /*! requires - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or add_tag_layer. - v is a function object with a signature equivalent to: v(any_net_type& t) That is, it must take any of the network types such as add_layer, add_loss_layer, etc. ensures - Loops over all the layers in net beginning with layer<0>(net) and going until a tag layer with an ID of tag_id is encountered. To be specific, this function essentially performs the following: size_t i = 0; while(layer(net) isn't an add_tag_layer with ID == tag_id) { v(layer(net)); ++i; } v(layer(net)); // also visits the tag layer itself at the very end. 
!*/ // ---------------------------------------------------------------------------------------- struct layer_test_results { /*! WHAT THIS OBJECT REPRESENTS This is the result type returned by test_layer(). - was_good == false if and only if the tested layer failed test_layer()'s checks. - log == a string describing why the testing failed when was_good == false. - Converting this object to bool simply returns was_good, so a result can be used directly as a condition. NOTE(review): was_good has no default initializer in this declaration. !*/ std::string log; bool was_good; operator bool() const { return was_good; } }; inline std::ostream& operator<< (std::ostream& out, const layer_test_results& item) { out << item.log; return out; } /*! ensures - writes item.log (and nothing else) to out. - returns out so calls can be chained. !*/ template < typename layer_details_type > layer_test_results test_layer ( layer_details_type l ); /*! ensures - Checks if l correctly implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined in layers_abstract.h. Importantly, it computes numerical approximations to the gradients and compares them to the outputs of the layer. - The results of the testing are returned. In particular, if the returned object is RESULT then we will have: - RESULT.was_good == false if and only if the layer failed the testing. - RESULT.log == a string describing why the testing failed if was_good==false. - Note that this function is only capable of checking layers that take arbitrary subnetworks as input. So if you have designed a layer that expects only a certain restricted type of subnetwork then you might get a compile or runtime error when you call this function. !*/ // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_CORE_ABSTRACT_H_ ================================================ FILE: dlib/dnn/input.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_INPUT_H_ #define DLIB_DNn_INPUT_H_ #include "input_abstract.h" #include "../matrix.h" #include "../array2d.h" #include "../pixel.h" #include "../image_processing.h" #include #include #include "../cuda/tensor_tools.h" namespace dlib { // ---------------------------------------------------------------------------------------- template class input { const static bool always_false = sizeof(T)!=sizeof(T); static_assert(always_false, "Unsupported type given to input<>. input<> only supports " "dlib::matrix and dlib::array2d objects."); }; // ---------------------------------------------------------------------------------------- template class input_rgb_image_sized; class input_rgb_image_pair; class input_rgb_image { public: typedef matrix input_type; input_rgb_image ( ) : avg_red(122.782f), avg_green(117.001f), avg_blue(104.298f) { } input_rgb_image ( float avg_red_, float avg_green_, float avg_blue_ ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_) {} template inline input_rgb_image ( const input_rgb_image_sized& item ); inline input_rgb_image ( const input_rgb_image_pair& item ); float get_avg_red() const { return avg_red; } float get_avg_green() const { return avg_green; } float get_avg_blue() const { return avg_blue; } bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin,iend) > 0); const auto nr = ibegin->nr(); const auto nc = ibegin->nc(); // make sure all the input matrices have the same dimensions for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->nr()==nr && i->nc()==nc, "\t input_rgb_image::to_tensor()" << "\n\t All 
matrices given to to_tensor() must have the same dimensions." << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t i->nr(): " << i->nr() << "\n\t i->nc(): " << i->nc() ); } // initialize data to the right size to contain the stuff in the iterator range. data.set_size(std::distance(ibegin,iend), 3, nr, nc); const size_t offset = nr*nc; auto ptr = data.host(); for (auto i = ibegin; i != iend; ++i) { for (long r = 0; r < nr; ++r) { for (long c = 0; c < nc; ++c) { rgb_pixel temp = (*i)(r,c); auto p = ptr++; *p = (temp.red-avg_red)/256.0; p += offset; *p = (temp.green-avg_green)/256.0; p += offset; *p = (temp.blue-avg_blue)/256.0; p += offset; } } ptr += offset*(data.k()-1); } } friend void serialize(const input_rgb_image& item, std::ostream& out) { serialize("input_rgb_image", out); serialize(item.avg_red, out); serialize(item.avg_green, out); serialize(item.avg_blue, out); } friend void deserialize(input_rgb_image& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "input_rgb_image" && version != "input_rgb_image_sized" && version != "input_rgb_image_pair") throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image."); deserialize(item.avg_red, in); deserialize(item.avg_green, in); deserialize(item.avg_blue, in); // read and discard the sizes if this was really a sized input layer. 
if (version == "input_rgb_image_sized") { size_t nr, nc; deserialize(nr, in); deserialize(nc, in); } } friend std::ostream& operator<<(std::ostream& out, const input_rgb_image& item) { out << "input_rgb_image("<\n"; } private: float avg_red; float avg_green; float avg_blue; }; // ---------------------------------------------------------------------------------------- template class input_rgb_image_sized { public: static_assert(NR != 0 && NC != 0, "The input image can't be empty."); typedef matrix input_type; input_rgb_image_sized ( ) : avg_red(122.782), avg_green(117.001), avg_blue(104.298) { } input_rgb_image_sized ( const input_rgb_image& item ) : avg_red(item.get_avg_red()), avg_green(item.get_avg_green()), avg_blue(item.get_avg_blue()) {} input_rgb_image_sized ( float avg_red_, float avg_green_, float avg_blue_ ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_) {} float get_avg_red() const { return avg_red; } float get_avg_green() const { return avg_green; } float get_avg_blue() const { return avg_blue; } bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin,iend) > 0); // make sure all input images have the correct size for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->nr()==NR && i->nc()==NC, "\t input_rgb_image_sized::to_tensor()" << "\n\t All input images must have "<nr()<<" rows and "<nc()<<" columns." ); } // initialize data to the right size to contain the stuff in the iterator range. 
data.set_size(std::distance(ibegin,iend), 3, NR, NC); const size_t offset = NR*NC; auto ptr = data.host(); for (auto i = ibegin; i != iend; ++i) { for (size_t r = 0; r < NR; ++r) { for (size_t c = 0; c < NC; ++c) { rgb_pixel temp = (*i)(r,c); auto p = ptr++; *p = (temp.red-avg_red)/256.0; p += offset; *p = (temp.green-avg_green)/256.0; p += offset; *p = (temp.blue-avg_blue)/256.0; p += offset; } } ptr += offset*(data.k()-1); } } friend void serialize(const input_rgb_image_sized& item, std::ostream& out) { serialize("input_rgb_image_sized", out); serialize(item.avg_red, out); serialize(item.avg_green, out); serialize(item.avg_blue, out); serialize(NR, out); serialize(NC, out); } friend void deserialize(input_rgb_image_sized& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "input_rgb_image_sized") throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_sized."); deserialize(item.avg_red, in); deserialize(item.avg_green, in); deserialize(item.avg_blue, in); size_t nr, nc; deserialize(nr, in); deserialize(nc, in); if (nr != NR || nc != NC) { std::ostringstream sout; sout << "Wrong image dimensions found while deserializing dlib::input_rgb_image_sized.\n"; sout << "Expected "<\n"; } private: float avg_red; float avg_green; float avg_blue; }; // ---------------------------------------------------------------------------------------- template input_rgb_image:: input_rgb_image ( const input_rgb_image_sized& item ) : avg_red(item.get_avg_red()), avg_green(item.get_avg_green()), avg_blue(item.get_avg_blue()) {} // ---------------------------------------------------------------------------------------- class input_rgb_image_pair { public: typedef std::pair, matrix> input_type; input_rgb_image_pair ( ) : avg_red(122.782), avg_green(117.001), avg_blue(104.298) { } input_rgb_image_pair ( float avg_red, float avg_green, float avg_blue ) : avg_red(avg_red), avg_green(avg_green), avg_blue(avg_blue) {} 
inline input_rgb_image_pair ( const input_rgb_image& item ) : avg_red(item.get_avg_red()), avg_green(item.get_avg_green()), avg_blue(item.get_avg_blue()) {} template inline input_rgb_image_pair ( const input_rgb_image_sized& item ) : avg_red(item.get_avg_red()), avg_green(item.get_avg_green()), avg_blue(item.get_avg_blue()) {} float get_avg_red() const { return avg_red; } float get_avg_green() const { return avg_green; } float get_avg_blue() const { return avg_blue; } bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin, iend) > 0); const auto nr = ibegin->first.nr(); const auto nc = ibegin->first.nc(); // make sure all the input matrices have the same dimensions for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->first.nr() == nr && i->first.nc()==nc && i->second.nr() == nr && i->second.nc() == nc, "\t input_rgb_image_pair::to_tensor()" << "\n\t All matrices given to to_tensor() must have the same dimensions." << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t i->first.nr(): " << i->first.nr() << "\n\t i->first.nc(): " << i->first.nc() << "\n\t i->second.nr(): " << i->second.nr() << "\n\t i->second.nc(): " << i->second.nc() ); } // initialize data to the right size to contain the stuff in the iterator range. 
data.set_size(2 * std::distance(ibegin, iend), 3, nr, nc); const size_t offset = nr * nc; const size_t offset2 = data.size() / 2; auto ptr = data.host(); for (auto i = ibegin; i != iend; ++i) { for (long r = 0; r < nr; ++r) { for (long c = 0; c < nc; ++c) { rgb_pixel temp_first = i->first(r, c); rgb_pixel temp_second = i->second(r, c); auto p = ptr++; *p = (temp_first.red - avg_red) / 256.0; *(p + offset2) = (temp_second.red - avg_red) / 256.0; p += offset; *p = (temp_first.green - avg_green) / 256.0; *(p + offset2) = (temp_second.green - avg_green) / 256.0; p += offset; *p = (temp_first.blue - avg_blue) / 256.0; *(p + offset2) = (temp_second.blue - avg_blue) / 256.0; p += offset; } } ptr += offset * (data.k() - 1); } } friend void serialize(const input_rgb_image_pair& item, std::ostream& out) { serialize("input_rgb_image_pair", out); serialize(item.avg_red, out); serialize(item.avg_green, out); serialize(item.avg_blue, out); } friend void deserialize(input_rgb_image_pair& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "input_rgb_image_pair" && version != "input_rgb_image" && version != "input_rgb_image_sized") throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_pair."); deserialize(item.avg_red, in); deserialize(item.avg_green, in); deserialize(item.avg_blue, in); // read and discard the sizes if this was really a sized input layer. 
if (version == "input_rgb_image_sized") { size_t nr, nc; deserialize(nr, in); deserialize(nc, in); } } friend std::ostream& operator<<(std::ostream& out, const input_rgb_image_pair& item) { out << "input_rgb_image_pair("<< item.avg_red<<","<\n"; } private: float avg_red; float avg_green; float avg_blue; }; // ---------------------------------------------------------------------------------------- input_rgb_image:: input_rgb_image ( const input_rgb_image_pair& item ) : avg_red(item.get_avg_red()), avg_green(item.get_avg_green()), avg_blue(item.get_avg_blue()) {} // ---------------------------------------------------------------------------------------- template class input> { public: typedef matrix input_type; input() {} template input(const input>&) {} bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin,iend) > 0); const auto nr = ibegin->nr(); const auto nc = ibegin->nc(); // make sure all the input matrices have the same dimensions for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->nr()==nr && i->nc()==nc, "\t input::to_tensor()" << "\n\t All matrices given to to_tensor() must have the same dimensions." << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t i->nr(): " << i->nr() << "\n\t i->nc(): " << i->nc() ); } // initialize data to the right size to contain the stuff in the iterator range. 
data.set_size(std::distance(ibegin,iend), pixel_traits::num, nr, nc); typedef typename pixel_traits::basic_pixel_type bptype; const size_t offset = nr*nc; auto ptr = data.host(); for (auto i = ibegin; i != iend; ++i) { for (long r = 0; r < nr; ++r) { for (long c = 0; c < nc; ++c) { auto temp = pixel_to_vector((*i)(r,c)); auto p = ptr++; for (long j = 0; j < temp.size(); ++j) { if (is_same_type::value) *p = temp(j)/256.0; else *p = temp(j); p += offset; } } } ptr += offset*(data.k()-1); } } friend void serialize(const input& /*item*/, std::ostream& out) { serialize("input", out); } friend void deserialize(input& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "input") throw serialization_error("Unexpected version found while deserializing dlib::input."); } friend std::ostream& operator<<(std::ostream& out, const input& /*item*/) { out << "input"; return out; } friend void to_xml(const input& /*item*/, std::ostream& out) { out << "\n"; } }; // ---------------------------------------------------------------------------------------- template class input,K>> { public: typedef std::array,K> input_type; input() {} input(const input&) {} bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin,iend) > 0); DLIB_CASSERT(ibegin->size() != 0, "When using std::array inputs you can't give 0 sized arrays."); const auto nr = (*ibegin)[0].nr(); const auto nc = (*ibegin)[0].nc(); // make sure all the input matrices have the same dimensions for (auto i = ibegin; i != iend; ++i) { for (size_t k = 0; k < K; ++k) { const auto& arr = *i; 
DLIB_CASSERT(arr[k].nr()==nr && arr[k].nc()==nc, "\t input::to_tensor()" << "\n\t When using std::array as input, all matrices in a batch must have the same dimensions." << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t k: " << k << "\n\t arr[k].nr(): " << arr[k].nr() << "\n\t arr[k].nc(): " << arr[k].nc() ); } } // initialize data to the right size to contain the stuff in the iterator range. data.set_size(std::distance(ibegin,iend), K, nr, nc); auto ptr = data.host(); for (auto i = ibegin; i != iend; ++i) { for (size_t k = 0; k < K; ++k) { for (long r = 0; r < nr; ++r) { for (long c = 0; c < nc; ++c) { if (is_same_type::value) *ptr++ = (*i)[k](r,c)/256.0; else *ptr++ = (*i)[k](r,c); } } } } } friend void serialize(const input& /*item*/, std::ostream& out) { serialize("input>", out); } friend void deserialize(input& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "input>") throw serialization_error("Unexpected version found while deserializing dlib::input>."); } friend std::ostream& operator<<(std::ostream& out, const input& /*item*/) { out << "input>"; return out; } friend void to_xml(const input& /*item*/, std::ostream& out) { out << "\n"; } }; // ---------------------------------------------------------------------------------------- template class input> { public: typedef array2d input_type; input() {} input(const input&) {} template input(const input>&) {} bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin,iend) > 0); const auto nr = ibegin->nr(); const auto nc = ibegin->nc(); // make sure all the input 
matrices have the same dimensions for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->nr()==nr && i->nc()==nc, "\t input::to_tensor()" << "\n\t All array2d objects given to to_tensor() must have the same dimensions." << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t i->nr(): " << i->nr() << "\n\t i->nc(): " << i->nc() ); } // initialize data to the right size to contain the stuff in the iterator range. data.set_size(std::distance(ibegin,iend), pixel_traits::num, nr, nc); typedef typename pixel_traits::basic_pixel_type bptype; const size_t offset = nr*nc; auto ptr = data.host(); for (auto i = ibegin; i != iend; ++i) { for (long r = 0; r < nr; ++r) { for (long c = 0; c < nc; ++c) { auto temp = pixel_to_vector((*i)[r][c]); auto p = ptr++; for (long j = 0; j < temp.size(); ++j) { if (is_same_type::value) *p = temp(j)/256.0; else *p = temp(j); p += offset; } } } ptr += offset*(data.k()-1); } } friend void serialize(const input&, std::ostream& out) { serialize("input", out); } friend void deserialize(input&, std::istream& in) { std::string version; deserialize(version, in); if (version != "input") throw serialization_error("Unexpected version found while deserializing dlib::input."); } friend std::ostream& operator<<(std::ostream& out, const input&) { out << "input"; return out; } friend void to_xml(const input&, std::ostream& out) { out << "\n"; } }; // ---------------------------------------------------------------------------------------- namespace detail { template class input_image_pyramid { public: virtual ~input_image_pyramid() = 0; typedef PYRAMID_TYPE pyramid_type; unsigned long get_pyramid_padding() const { return pyramid_padding; } void set_pyramid_padding(unsigned long value) { pyramid_padding = value; } unsigned long get_pyramid_outer_padding() const { return pyramid_outer_padding; } void set_pyramid_outer_padding(unsigned long value) { pyramid_outer_padding = value; } bool image_contained_point( const tensor& data, const point& p ) const { auto&& rects 
= any_cast>(data.annotation()); DLIB_CASSERT(rects.size() > 0); return rects[0].contains(p + rects[0].tl_corner()); } drectangle tensor_space_to_image_space( const tensor& data, drectangle r ) const { auto&& rects = any_cast>(data.annotation()); return tiled_pyramid_to_image(rects, r); } drectangle image_space_to_tensor_space ( const tensor& data, double scale, drectangle r ) const { DLIB_CASSERT(0 < scale && scale <= 1, "scale: " << scale); auto&& rects = any_cast>(data.annotation()); return image_to_tiled_pyramid(rects, scale, r); } protected: template void to_tensor_init ( forward_iterator ibegin, forward_iterator iend, resizable_tensor &data, unsigned int k ) const { DLIB_CASSERT(std::distance(ibegin, iend) > 0); auto nr = ibegin->nr(); auto nc = ibegin->nc(); // make sure all the input matrices have the same dimensions for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->nr() == nr && i->nc() == nc, "\t input_grayscale_image_pyramid::to_tensor()" << "\n\t All matrices given to to_tensor() must have the same dimensions." << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t i->nr(): " << i->nr() << "\n\t i->nc(): " << i->nc() ); } long NR, NC; pyramid_type pyr; auto& rects = data.annotation().get>(); impl::compute_tiled_image_pyramid_details(pyr, nr, nc, pyramid_padding, pyramid_outer_padding, rects, NR, NC); // initialize data to the right size to contain the stuff in the iterator range. data.set_size(std::distance(ibegin, iend), k, NR, NC); // We need to zero the image before doing the pyramid, since the pyramid // creation code doesn't write to all parts of the image. We also take // care to avoid triggering any device to hosts copies. auto ptr = data.host_write_only(); for (size_t i = 0; i < data.size(); ++i) ptr[i] = 0; } // now build the image pyramid into data. This does the same thing as // standard create_tiled_pyramid(), except we use the GPU if one is available. 
void create_tiled_pyramid ( const std::vector& rects, resizable_tensor& data ) const { for (size_t i = 1; i < rects.size(); ++i) { alias_tensor src(data.num_samples(), data.k(), rects[i - 1].height(), rects[i - 1].width()); alias_tensor dest(data.num_samples(), data.k(), rects[i].height(), rects[i].width()); auto asrc = src(data, data.nc() * rects[i - 1].top() + rects[i - 1].left()); auto adest = dest(data, data.nc() * rects[i].top() + rects[i].left()); tt::resize_bilinear(adest, data.nc(), data.nr() * data.nc(), asrc, data.nc(), data.nr() * data.nc()); } } unsigned long pyramid_padding = 10; unsigned long pyramid_outer_padding = 11; }; template input_image_pyramid::~input_image_pyramid() {} } // ---------------------------------------------------------------------------------------- template class input_grayscale_image_pyramid : public detail::input_image_pyramid { public: typedef matrix input_type; typedef PYRAMID_TYPE pyramid_type; template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { this->to_tensor_init(ibegin, iend, data, 1); const auto rects = data.annotation().get>(); if (rects.size() == 0) return; // copy the first raw image into the top part of the tiled pyramid. We need to // do this for each of the input images/samples in the tensor. 
auto ptr = data.host_write_only(); for (auto i = ibegin; i != iend; ++i) { auto& img = *i; ptr += rects[0].top()*data.nc(); for (long r = 0; r < img.nr(); ++r) { auto p = ptr+rects[0].left(); for (long c = 0; c < img.nc(); ++c) p[c] = (img(r,c))/256.0; ptr += data.nc(); } ptr += data.nc()*(data.nr()-rects[0].bottom()-1); } this->create_tiled_pyramid(rects, data); } friend void serialize(const input_grayscale_image_pyramid& item, std::ostream& out) { serialize("input_grayscale_image_pyramid", out); serialize(item.pyramid_padding, out); serialize(item.pyramid_outer_padding, out); } friend void deserialize(input_grayscale_image_pyramid& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "input_grayscale_image_pyramid") throw serialization_error("Unexpected version found while deserializing dlib::input_grayscale_image_pyramid."); deserialize(item.pyramid_padding, in); deserialize(item.pyramid_outer_padding, in); } friend std::ostream& operator<<(std::ostream& out, const input_grayscale_image_pyramid& item) { out << "input_grayscale_image_pyramid()"; out << " pyramid_padding="<\n"; } }; // ---------------------------------------------------------------------------------------- template class input_rgb_image_pyramid : public detail::input_image_pyramid { public: typedef matrix input_type; typedef PYRAMID_TYPE pyramid_type; input_rgb_image_pyramid ( ) : avg_red(122.782), avg_green(117.001), avg_blue(104.298) { } input_rgb_image_pyramid ( float avg_red_, float avg_green_, float avg_blue_ ) : avg_red(avg_red_), avg_green(avg_green_), avg_blue(avg_blue_) {} float get_avg_red() const { return avg_red; } float get_avg_green() const { return avg_green; } float get_avg_blue() const { return avg_blue; } template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { this->to_tensor_init(ibegin, iend, data, 3); const auto rects = data.annotation().get>(); if (rects.size() == 0) return; // copy the 
first raw image into the top part of the tiled pyramid. We need to // do this for each of the input images/samples in the tensor. auto ptr = data.host_write_only(); for (auto i = ibegin; i != iend; ++i) { auto& img = *i; ptr += rects[0].top()*data.nc(); for (long r = 0; r < img.nr(); ++r) { auto p = ptr+rects[0].left(); for (long c = 0; c < img.nc(); ++c) p[c] = (img(r,c).red-avg_red)/256.0; ptr += data.nc(); } ptr += data.nc()*(data.nr()-rects[0].bottom()-1); ptr += rects[0].top()*data.nc(); for (long r = 0; r < img.nr(); ++r) { auto p = ptr+rects[0].left(); for (long c = 0; c < img.nc(); ++c) p[c] = (img(r,c).green-avg_green)/256.0; ptr += data.nc(); } ptr += data.nc()*(data.nr()-rects[0].bottom()-1); ptr += rects[0].top()*data.nc(); for (long r = 0; r < img.nr(); ++r) { auto p = ptr+rects[0].left(); for (long c = 0; c < img.nc(); ++c) p[c] = (img(r,c).blue-avg_blue)/256.0; ptr += data.nc(); } ptr += data.nc()*(data.nr()-rects[0].bottom()-1); } this->create_tiled_pyramid(rects, data); } friend void serialize(const input_rgb_image_pyramid& item, std::ostream& out) { serialize("input_rgb_image_pyramid2", out); serialize(item.avg_red, out); serialize(item.avg_green, out); serialize(item.avg_blue, out); serialize(item.pyramid_padding, out); serialize(item.pyramid_outer_padding, out); } friend void deserialize(input_rgb_image_pyramid& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "input_rgb_image_pyramid" && version != "input_rgb_image_pyramid2") throw serialization_error("Unexpected version found while deserializing dlib::input_rgb_image_pyramid."); deserialize(item.avg_red, in); deserialize(item.avg_green, in); deserialize(item.avg_blue, in); if (version == "input_rgb_image_pyramid2") { deserialize(item.pyramid_padding, in); deserialize(item.pyramid_outer_padding, in); } else { item.pyramid_padding = 10; item.pyramid_outer_padding = 11; } } friend std::ostream& operator<<(std::ostream& out, const input_rgb_image_pyramid& 
item) { out << "input_rgb_image_pyramid("<\n"; } private: float avg_red; float avg_green; float avg_blue; }; // ---------------------------------------------------------------------------------------- class input_tensor { public: typedef tensor input_type; input_tensor() {} input_tensor(const input_tensor&) {} template void to_tensor( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const { DLIB_CASSERT(std::distance(ibegin, iend) > 0); const auto k = ibegin->k(); const auto nr = ibegin->nr(); const auto nc = ibegin->nc(); // make sure all the input tensors have the same dimensions for (auto i = ibegin; i != iend; ++i) { DLIB_CASSERT(i->k() == k && i->nr() == nr && i->nc() == nc, "\t input_tensor::to_tensor()" << "\n\t All tensor objects given to to_tensor() must have the same dimensions." << "\n\t k: " << k << "\n\t nr: " << nr << "\n\t nc: " << nc << "\n\t i->k(): " << i->k() << "\n\t i->nr(): " << i->nr() << "\n\t i->nc(): " << i->nc() ); } const auto num_samples = count_samples(ibegin, iend); // initialize data to the right size to contain the stuff in the iterator range. 
data.set_size(num_samples, k, nr, nc); const size_t stride = k * nr * nc; size_t offset = 0; for (auto i = ibegin; i != iend; ++i) { alias_tensor slice(i->num_samples(), k, nr, nc); memcpy(slice(data, offset), *i); offset += slice.num_samples() * stride; } } friend void serialize(const input_tensor&, std::ostream& out) { serialize("input_tensor", out); } friend void deserialize(input_tensor&, std::istream& in) { std::string version; deserialize(version, in); if (version != "input_tensor") throw serialization_error("Unexpected version found while deserializing dlib::input_tensor."); } friend std::ostream& operator<<(std::ostream& out, const input_tensor&) { out << "input_tensor"; return out; } friend void to_xml(const input_tensor&, std::ostream& out) { out << "\n"; } private: template long long count_samples( forward_iterator ibegin, forward_iterator iend ) const { return std::accumulate(ibegin, iend, 0, [](long long a, const auto& b) { return a + b.num_samples(); }); } }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_INPUT_H_ ================================================ FILE: dlib/dnn/input_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_DNn_INPUT_ABSTRACT_H_ #ifdef DLIB_DNn_INPUT_ABSTRACT_H_ #include "../matrix.h" #include "../pixel.h" namespace dlib { // ---------------------------------------------------------------------------------------- class EXAMPLE_INPUT_LAYER { /*! WHAT THIS OBJECT REPRESENTS Each deep neural network model in dlib begins with an input layer. The job of the input layer is to convert an input_type into a tensor. Nothing more and nothing less. Note that there is no dlib::EXAMPLE_INPUT_LAYER type. It is shown here purely to document the interface that an input layer object must implement. 
If you are using some kind of image or matrix object as your input_type then you can use the provided dlib::input layer defined below. Otherwise, you need to define your own custom input layer. THREAD SAFETY to_tensor() must be thread safe. That is, multiple threads must be able to make calls to to_tensor() on a single instance of this object at the same time. !*/ public: EXAMPLE_INPUT_LAYER( ); /*! ensures - Default constructs this object. This function is not required to do anything in particular but it must exist, that is, it is required that layer objects be default constructable. !*/ EXAMPLE_INPUT_LAYER ( const EXAMPLE_INPUT_LAYER& item ); /*! ensures - EXAMPLE_INPUT_LAYER objects are copy constructable !*/ EXAMPLE_INPUT_LAYER( const some_other_input_layer_type& item ); /*! ensures - Constructs this object from item. This form of constructor is optional but it allows you to provide a conversion from one input layer type to another. For example, the following code is valid only if my_input_layer2 can be constructed from my_input_layer1: relu>>> my_dnn1; relu>>> my_dnn2(my_dnn1); This kind of pattern is useful if you want to use one type of input layer during training but a different type of layer during testing since it allows you to easily convert between related deep neural network types. !*/ typedef whatever_type_to_tensor_expects input_type; template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 ensures - Converts the iterator range into a tensor and stores it into #data. - #data.num_samples()%distance(ibegin,iend) == 0. Normally you would have #data.num_samples() == distance(ibegin,iend) but you can also expand the output by some integer factor so long as the loss you use can deal with it correctly. 
- The data in the ith sample of #data corresponds to the input_type object *(ibegin+i/sample_expansion_factor). where sample_expansion_factor==#data.num_samples()/distance(ibegin,iend). !*/ }; std::ostream& operator<<(std::ostream& out, const EXAMPLE_INPUT_LAYER& item); /*! print a string describing this layer. !*/ void to_xml(const EXAMPLE_INPUT_LAYER& item, std::ostream& out); /*! This function is optional, but required if you want to print your networks with net_to_xml(). Therefore, to_xml() prints a layer as XML. !*/ void serialize(const EXAMPLE_INPUT_LAYER& item, std::ostream& out); void deserialize(EXAMPLE_INPUT_LAYER& item, std::istream& in); /*! provides serialization support !*/ // ---------------------------------------------------------------------------------------- template < typename T > class input { /*! REQUIREMENTS ON T One of the following must be true: - T is a matrix or array2d object and it must contain some kind of pixel type. I.e. pixel_traits must be defined. - T is a std::array> where U is any built in scalar type like float, double, or unsigned char. WHAT THIS OBJECT REPRESENTS This is a basic input layer that simply copies images into a tensor. !*/ public: typedef T input_type; template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - The input range should contain image objects that all have the same dimensions. ensures - Converts the iterator range into a tensor and stores it into #data. In particular, if the input images have R rows, C columns, and K channels (where K is given by pixel_traits::num or std::array::size() if std::array inputs are used) then we will have: - #data.num_samples() == std::distance(ibegin,iend) - #data.nr() == R - #data.nc() == C - #data.k() == K For example, a matrix would turn into a tensor with 3 rows, 3 columns, and k()==1. 
Or a matrix would turn into a tensor with 4 rows, 5 columns, and k()==3 (since rgb_pixels have 3 channels). Or a std::array,5> would turn into a tensor with 3 rows and columns, and k()==5 channels. - If the input data contains pixels of type unsigned char, rgb_pixel, or other pixel types with a basic_pixel_type of unsigned char then each value written to the output tensor is first divided by 256.0 so that the resulting outputs are all in the range [0,1]. !*/ // Provided for compatibility with input_rgb_image_pyramid's interface bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } }; // ---------------------------------------------------------------------------------------- class input_rgb_image { /*! WHAT THIS OBJECT REPRESENTS This input layer works with RGB images of type matrix. It is very similar to the dlib::input layer except that it allows you to subtract the average color value from each color channel when converting an image to a tensor. !*/ public: typedef matrix input_type; input_rgb_image ( ); /*! ensures - #get_avg_red() == 122.782 - #get_avg_green() == 117.001 - #get_avg_blue() == 104.298 !*/ input_rgb_image ( float avg_red, float avg_green, float avg_blue ); /*! ensures - #get_avg_red() == avg_red - #get_avg_green() == avg_green - #get_avg_blue() == avg_blue !*/ float get_avg_red( ) const; /*! ensures - returns the value subtracted from the red color channel. !*/ float get_avg_green( ) const; /*! ensures - returns the value subtracted from the green color channel. !*/ float get_avg_blue( ) const; /*! ensures - returns the value subtracted from the blue color channel. !*/ template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! 
requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - The input range should contain images that all have the same dimensions. ensures - Converts the iterator range into a tensor and stores it into #data. In particular, if the input images have R rows, C columns then we will have: - #data.num_samples() == std::distance(ibegin,iend) - #data.nr() == R - #data.nc() == C - #data.k() == 3 Moreover, each color channel is normalized by having its average value subtracted (according to get_avg_red(), get_avg_green(), or get_avg_blue()) and then is divided by 256.0. !*/ // Provided for compatibility with input_rgb_image_pyramid's interface bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } }; // ---------------------------------------------------------------------------------------- template class input_rgb_image_sized { /*! WHAT THIS OBJECT REPRESENTS This layer has an interface and behavior identical to input_rgb_image except that it requires input images to have NR rows and NC columns. This is checked by a DLIB_CASSERT inside to_tensor(). You can also convert between input_rgb_image and input_rgb_image_sized by copy construction or assignment. !*/ }; // ---------------------------------------------------------------------------------------- class input_rgb_image_pair { /*! WHAT THIS OBJECT REPRESENTS This input layer works with std::pair of RGB images of type matrix. It is useful when you want to input image pairs that are related to each other, for instance, they are different distorted views of the same original image. It is mainly supposed to be used with unsupervised loss functions such as loss_barlow_twins_. 
You can also convert between input_rgb_image and input_rgb_image_pair by copy construction or assignment. !*/ public: typedef std::pair, matrix> input_type; input_rgb_image_pair ( ); /*! ensures - #get_avg_red() == 122.782 - #get_avg_green() == 117.001 - #get_avg_blue() == 104.298 !*/ input_rgb_image_pair ( float avg_red, float avg_green, float avg_blue ); /*! ensures - #get_avg_red() == avg_red - #get_avg_green() == avg_green - #get_avg_blue() == avg_blue !*/ inline input_rgb_image_pair ( const input_rgb_image& item ); /*! ensures - #get_avg_red() == item.get_avg_red() - #get_avg_green() == item.get_avg_green() - #get_avg_blue() == item.get_avg_blue() !*/ template inline input_rgb_image_pair ( const input_rgb_image_sized& item ); /*! ensures - #get_avg_red() == item.get_avg_red() - #get_avg_green() == item.get_avg_green() - #get_avg_blue() == item.get_avg_blue() !*/ float get_avg_red( ) const; /*! ensures - returns the value subtracted from the red color channel. !*/ float get_avg_green( ) const; /*! ensures - returns the value subtracted from the green color channel. !*/ float get_avg_blue( ) const; /*! ensures - returns the value subtracted from the blue color channel. !*/ void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - The input range should contain images that all have the same dimensions. ensures - Converts the iterator range into a tensor and stores it into #data. In particular, if the input images have R rows, C columns then we will have: - #data.num_samples() == 2 * std::distance(ibegin,iend) - #data.nr() == R - #data.nc() == C - #data.k() == 3 Moreover, each color channel is normalized by having its average value subtracted (according to get_avg_red(), get_avg_green(), or get_avg_blue()) and then is divided by 256.0. 
Additionally, the first elements in each pair are placed in the first half of the batch, and the second elements in the second half. !*/ // Provided for compatibility with input_rgb_image_pyramid's interface bool image_contained_point ( const tensor& data, const point& p) const { return get_rect(data).contains(p); } drectangle tensor_space_to_image_space ( const tensor& /*data*/, drectangle r) const { return r; } drectangle image_space_to_tensor_space ( const tensor& /*data*/, double /*scale*/, drectangle r ) const { return r; } // ---------------------------------------------------------------------------------------- template < typename PYRAMID_TYPE > class input_grayscale_image_pyramid { /*! REQUIREMENTS ON PYRAMID_TYPE PYRAMID_TYPE must be an instance of the dlib::pyramid_down template. WHAT THIS OBJECT REPRESENTS This input layer works with gray scale images of type matrix. It is identical to input layer except that it outputs a tensor containing a tiled image pyramid of each input image rather than a simple copy of each image. The tiled image pyramid is created using create_tiled_pyramid(). !*/ public: typedef matrix input_type; typedef PYRAMID_TYPE pyramid_type; input_grayscale_image_pyramid ( ); /*! ensures - #get_pyramid_padding() == 10 - #get_pyramid_outer_padding() == 11 !*/ unsigned long get_pyramid_padding ( ) const; /*! ensures - When this object creates a pyramid it will call create_tiled_pyramid() and set create_tiled_pyramid's pyramid_padding parameter to get_pyramid_padding(). !*/ void set_pyramid_padding ( unsigned long value ); /*! ensures - #get_pyramid_padding() == value !*/ unsigned long get_pyramid_outer_padding ( ) const; /*! ensures - When this object creates a pyramid it will call create_tiled_pyramid() and set create_tiled_pyramid's pyramid_outer_padding parameter to get_pyramid_outer_padding(). !*/ void set_pyramid_outer_padding ( unsigned long value ); /*! 
ensures - #get_pyramid_outer_padding() == value !*/ template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - The input range should contain images that all have the same dimensions. ensures - Converts the iterator range into a tensor and stores it into #data. In particular, we will have: - #data.num_samples() == std::distance(ibegin,iend) - #data.k() == 1 - Each sample in #data contains a tiled image pyramid of the corresponding input image. The tiled pyramid is created by create_tiled_pyramid(). Moreover, each pixel is normalized, dividing them by 256.0. !*/ bool image_contained_point ( const tensor& data, const point& p ) const; /*! requires - data is a tensor that was produced by this->to_tensor() ensures - Since data is a tensor that is built from a bunch of identically sized images, we can ask if those images were big enough to contain the point p. This function returns the answer to that question. !*/ drectangle image_space_to_tensor_space ( const tensor& data, double scale, drectangle r ) const; /*! requires - data is a tensor that was produced by this->to_tensor() - 0 < scale <= 1 ensures - This function maps from to_tensor()'s input image space to its output tensor space. Therefore, given that data is a tensor produced by to_tensor(), image_space_to_tensor_space() allows you to ask for the rectangle in data that corresponds to a rectangle in the original image space. Note that since the output tensor contains an image pyramid, there are multiple points in the output tensor that correspond to any input location. So you must also specify a scale so we know what level of the pyramid is needed. So given a rectangle r in an input image, you can ask, what rectangle in data corresponds to r when things are scale times smaller? That rectangle is returned by this function. 
- A scale of 1 means we don't move anywhere in the pyramid scale space relative to the input image while smaller values of scale mean we move down the pyramid. !*/ drectangle tensor_space_to_image_space ( const tensor& data, drectangle r ) const; /*! requires - data is a tensor that was produced by this->to_tensor() ensures - This function maps from to_tensor()'s output tensor space to its input image space. Therefore, given that data is a tensor produced by to_tensor(), tensor_space_to_image_space() allows you to ask for the rectangle in the input image that corresponds to a rectangle in data. - It should be noted that this function isn't always an inverse of image_space_to_tensor_space(). This is because you can ask image_space_to_tensor_space() for the coordinates of points outside the input image and they will be mapped to somewhere that doesn't have an inverse. But for points actually inside the input image this function performs an approximate inverse mapping. I.e. when image_contained_point(data,center(r))==true there is an approximate inverse. !*/ }; // ---------------------------------------------------------------------------------------- template < typename PYRAMID_TYPE > class input_rgb_image_pyramid { /*! REQUIREMENTS ON PYRAMID_TYPE PYRAMID_TYPE must be an instance of the dlib::pyramid_down template. WHAT THIS OBJECT REPRESENTS This input layer works with RGB images of type matrix. It is identical to input_rgb_image except that it outputs a tensor containing a tiled image pyramid of each input image rather than a simple copy of each image. The tiled image pyramid is created using create_tiled_pyramid(). !*/ public: typedef matrix input_type; typedef PYRAMID_TYPE pyramid_type; input_rgb_image_pyramid ( ); /*! ensures - #get_avg_red() == 122.782 - #get_avg_green() == 117.001 - #get_avg_blue() == 104.298 - #get_pyramid_padding() == 10 - #get_pyramid_outer_padding() == 11 !*/ input_rgb_image_pyramid ( float avg_red, float avg_green, float avg_blue ); /*! 
ensures - #get_avg_red() == avg_red - #get_avg_green() == avg_green - #get_avg_blue() == avg_blue - #get_pyramid_padding() == 10 - #get_pyramid_outer_padding() == 11 !*/ float get_avg_red( ) const; /*! ensures - returns the value subtracted from the red color channel. !*/ float get_avg_green( ) const; /*! ensures - returns the value subtracted from the green color channel. !*/ float get_avg_blue( ) const; /*! ensures - returns the value subtracted from the blue color channel. !*/ unsigned long get_pyramid_padding ( ) const; /*! ensures - When this object creates a pyramid it will call create_tiled_pyramid() and set create_tiled_pyramid's pyramid_padding parameter to get_pyramid_padding(). !*/ void set_pyramid_padding ( unsigned long value ); /*! ensures - #get_pyramid_padding() == value !*/ unsigned long get_pyramid_outer_padding ( ) const; /*! ensures - When this object creates a pyramid it will call create_tiled_pyramid() and set create_tiled_pyramid's pyramid_outer_padding parameter to get_pyramid_outer_padding(). !*/ void set_pyramid_outer_padding ( unsigned long value ); /*! ensures - #get_pyramid_outer_padding() == value !*/ template void to_tensor ( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - The input range should contain images that all have the same dimensions. ensures - Converts the iterator range into a tensor and stores it into #data. In particular, we will have: - #data.num_samples() == std::distance(ibegin,iend) - #data.k() == 3 - Each sample in #data contains a tiled image pyramid of the corresponding input image. The tiled pyramid is created by create_tiled_pyramid(). Moreover, each color channel is normalized by having its average value subtracted (according to get_avg_red(), get_avg_green(), or get_avg_blue()) and then is divided by 256.0. 
!*/ bool image_contained_point ( const tensor& data, const point& p ) const; /*! requires - data is a tensor that was produced by this->to_tensor() ensures - Since data is a tensor that is built from a bunch of identically sized images, we can ask if those images were big enough to contain the point p. This function returns the answer to that question. !*/ drectangle image_space_to_tensor_space ( const tensor& data, double scale, drectangle r ) const; /*! requires - data is a tensor that was produced by this->to_tensor() - 0 < scale <= 1 ensures - This function maps from to_tensor()'s input image space to its output tensor space. Therefore, given that data is a tensor produced by to_tensor(), image_space_to_tensor_space() allows you to ask for the rectangle in data that corresponds to a rectangle in the original image space. Note that since the output tensor contains an image pyramid, there are multiple points in the output tensor that correspond to any input location. So you must also specify a scale so we know what level of the pyramid is needed. So given a rectangle r in an input image, you can ask, what rectangle in data corresponds to r when things are scale times smaller? That rectangle is returned by this function. - A scale of 1 means we don't move anywhere in the pyramid scale space relative to the input image while smaller values of scale mean we move down the pyramid. !*/ drectangle tensor_space_to_image_space ( const tensor& data, drectangle r ) const; /*! requires - data is a tensor that was produced by this->to_tensor() ensures - This function maps from to_tensor()'s output tensor space to its input image space. Therefore, given that data is a tensor produced by to_tensor(), tensor_space_to_image_space() allows you to ask for the rectangle in the input image that corresponds to a rectangle in data. - It should be noted that this function isn't always an inverse of image_space_to_tensor_space(). 
This is because you can ask image_space_to_tensor_space() for the coordinates of points outside the input image and they will be mapped to somewhere that doesn't have an inverse. But for points actually inside the input image this function performs an approximate inverse mapping. I.e. when image_contained_point(data,center(r))==true there is an approximate inverse. !*/ }; // ---------------------------------------------------------------------------------------- class input_tensor { /*! WHAT THIS OBJECT REPRESENTS This input layer works with dlib::tensor objects. It is very similar to the dlib::input layer except that it allows for concatenating data that already resides in GPU memory. !*/ public: typedef tensor input_type; input_tensor( ); /*! ensures - input_tensor objects are default constructable !*/ input_tensor( const input_tensor& item ); /*! ensures - input_tensor objects are copy constructable !*/ template void to_tensor( forward_iterator ibegin, forward_iterator iend, resizable_tensor& data ) const; /*! requires - [ibegin, iend) is an iterator range over input_type objects. - std::distance(ibegin,iend) > 0 - The input range should contain tensor objects that all have the same dimensions. ensures - Copies the iterator range into #data. In particular, if the input tensors have R rows, C columns, and K channels then we will have: - #data.num_samples() == count_samples(ibegin,iend) - #data.nr() == R - #data.nc() == C - #data.k() == K This results in a tensor concatenation along the sample dimension. !*/ }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_INPUT_ABSTRACT_H_ ================================================ FILE: dlib/dnn/layers.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DNn_LAYERS_H_ #define DLIB_DNn_LAYERS_H_ #include "layers_abstract.h" #include "../cuda/tensor.h" #include "core.h" #include #include #include "../rand.h" #include "../string.h" #include "../cuda/tensor_tools.h" #include "../vectorstream.h" #include "utilities.h" #include "../cuda/operation_mode.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- struct num_con_outputs { num_con_outputs(unsigned long n) : num_outputs(n) {} unsigned long num_outputs; }; template < long _num_filters, long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class con_ { public: static_assert(_num_filters > 0, "The number of filters must be > 0"); static_assert(_nr >= 0, "The number of rows in a filter must be >= 0"); static_assert(_nc >= 0, "The number of columns in a filter must be >= 0"); static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(_nr==0 || (0 <= _padding_y && _padding_y < _nr), "The padding must be smaller than the filter size."); static_assert(_nc==0 || (0 <= _padding_x && _padding_x < _nc), "The padding must be smaller than the filter size."); static_assert(_nr!=0 || 0 == _padding_y, "If _nr==0 then the padding must be set to 0 as well."); static_assert(_nc!=0 || 0 == _padding_x, "If _nr==0 then the padding must be set to 0 as well."); con_( num_con_outputs o ) : learning_rate_multiplier(1), weight_decay_multiplier(1), bias_learning_rate_multiplier(1), bias_weight_decay_multiplier(0), num_filters_(o.num_outputs), padding_y_(_padding_y), padding_x_(_padding_x), use_bias(true), use_relu(false) { DLIB_CASSERT(num_filters_ > 0); } con_() : con_(num_con_outputs(_num_filters)) {} long num_filters() const { return num_filters_; } long nr() const { if (_nr==0) return filters.nr(); else return _nr; } long nc() const { if (_nc==0) return 
filters.nc(); else return _nc; } long stride_y() const { return _stride_y; } long stride_x() const { return _stride_x; } long padding_y() const { return padding_y_; } long padding_x() const { return padding_x_; } void set_num_filters(long num) { DLIB_CASSERT(num > 0); if (num != num_filters_) { DLIB_CASSERT(get_layer_params().size() == 0, "You can't change the number of filters in con_ if the parameter tensor has already been allocated."); num_filters_ = num; } } double get_learning_rate_multiplier () const { return learning_rate_multiplier; } double get_weight_decay_multiplier () const { return weight_decay_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } bool relu_is_disabled() const { return !use_relu; } void disable_relu() { use_relu = false; } void enable_relu() { use_relu = true; } bool bias_is_disabled() const { return !use_bias; } void disable_bias() { if (use_bias == false) return; use_bias = false; if (params.size() == 0) return; DLIB_CASSERT(params.size() == filters.size() + num_filters_); auto temp = params; params.set_size(params.size() - num_filters_); std::copy(temp.begin(), temp.end() - num_filters_, params.begin()); biases = alias_tensor(); } void enable_bias() { if (use_bias == true) return; use_bias = true; if (params.size() == 0) return; DLIB_CASSERT(params.size() == filters.size()); auto temp = params; params.set_size(params.size() + num_filters_); std::copy(temp.begin(), temp.end(), params.begin()); biases = alias_tensor(1, num_filters_); biases(params, 
filters.size()) = 0; } inline dpoint map_input_to_output ( dpoint p ) const { p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); return p; } inline dpoint map_output_to_input ( dpoint p ) const { p.x() = p.x()*stride_x() - padding_x() + nc()/2; p.y() = p.y()*stride_y() - padding_y() + nr()/2; return p; } con_ ( const con_& item ) : params(item.params), filters(item.filters), biases(item.biases), learning_rate_multiplier(item.learning_rate_multiplier), weight_decay_multiplier(item.weight_decay_multiplier), bias_learning_rate_multiplier(item.bias_learning_rate_multiplier), bias_weight_decay_multiplier(item.bias_weight_decay_multiplier), num_filters_(item.num_filters_), padding_y_(item.padding_y_), padding_x_(item.padding_x_), use_bias(item.use_bias), use_relu(item.use_relu) { // this->conv is non-copyable and basically stateless, so we have to write our // own copy to avoid trying to copy it and getting an error. } con_& operator= ( const con_& item ) { if (this == &item) return *this; // this->conv is non-copyable and basically stateless, so we have to write our // own copy to avoid trying to copy it and getting an error. params = item.params; filters = item.filters; biases = item.biases; padding_y_ = item.padding_y_; padding_x_ = item.padding_x_; learning_rate_multiplier = item.learning_rate_multiplier; weight_decay_multiplier = item.weight_decay_multiplier; bias_learning_rate_multiplier = item.bias_learning_rate_multiplier; bias_weight_decay_multiplier = item.bias_weight_decay_multiplier; num_filters_ = item.num_filters_; use_bias = item.use_bias; use_relu = item.use_relu; return *this; } template void setup (const SUBNET& sub) { const long filt_nr = _nr!=0 ? _nr : sub.get_output().nr(); const long filt_nc = _nc!=0 ? _nc : sub.get_output().nc(); long num_inputs = filt_nr*filt_nc*sub.get_output().k(); long num_outputs = num_filters_; // allocate params for the filters and also for the filter bias values. 
params.set_size(num_inputs*num_filters_ + static_cast(use_bias) * num_filters_); dlib::rand rnd(std::rand()); randomize_parameters(params, num_inputs+num_outputs, rnd); filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc); if (use_bias) { biases = alias_tensor(1,num_filters_); // set the initial bias values to zero biases(params,filters.size()) = 0; } } template void forward(const SUBNET& sub, resizable_tensor& output) { conv.setup(sub.get_output(), filters(params,0), _stride_y, _stride_x, padding_y_, padding_x_); if (use_bias) { conv(false, output, sub.get_output(), filters(params,0), biases(params, filters.size()), use_relu); } else { conv(false, output, sub.get_output(), filters(params,0)); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { conv.get_gradient_for_data (true, gradient_input, filters(params,0), sub.get_gradient_input()); // no point computing the parameter gradients if they won't be used. if (learning_rate_multiplier != 0) { auto filt = filters(params_grad,0); conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt); if (use_bias) { auto b = biases(params_grad, filters.size()); tt::assign_conv_bias_gradient(b, gradient_input); } } } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const con_& item, std::ostream& out) { serialize("con_6", out); serialize(item.params, out); serialize(item.num_filters_, out); serialize(_nr, out); serialize(_nc, out); serialize(_stride_y, out); serialize(_stride_x, out); serialize(item.padding_y_, out); serialize(item.padding_x_, out); serialize(item.filters, out); serialize(item.biases, out); serialize(item.learning_rate_multiplier, out); serialize(item.weight_decay_multiplier, out); serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_weight_decay_multiplier, out); serialize(item.use_bias, out); serialize(item.use_relu, out); } friend 
void deserialize(con_& item, std::istream& in) { std::string version; deserialize(version, in); long nr; long nc; int stride_y; int stride_x; if (version == "con_4" || version == "con_5" || version == "con_6") { deserialize(item.params, in); deserialize(item.num_filters_, in); deserialize(nr, in); deserialize(nc, in); deserialize(stride_y, in); deserialize(stride_x, in); deserialize(item.padding_y_, in); deserialize(item.padding_x_, in); deserialize(item.filters, in); deserialize(item.biases, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.weight_decay_multiplier, in); deserialize(item.bias_learning_rate_multiplier, in); deserialize(item.bias_weight_decay_multiplier, in); if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_"); if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_"); if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_"); if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_"); if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_"); if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_"); if (version == "con_5" || version == "con_6") { deserialize(item.use_bias, in); } if (version == "con_6") { deserialize(item.use_relu, in); } } else { throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_."); } } friend std::ostream& operator<<(std::ostream& out, const con_& item) { out << "con\t (" << "num_filters="<\n"; out << mat(item.params); out << "\n"; } private: resizable_tensor params; alias_tensor filters, biases; tt::tensor_conv conv; double learning_rate_multiplier; double weight_decay_multiplier; double bias_learning_rate_multiplier; double bias_weight_decay_multiplier; long num_filters_; // These 
are here only because older versions of con (which you might encounter // serialized to disk) used different padding settings. int padding_y_; int padding_x_; bool use_bias; bool use_relu; }; template < long num_filters, long nr, long nc, int stride_y, int stride_x, typename SUBNET > using con = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _num_filters, long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class cont_ { public: static_assert(_num_filters > 0, "The number of filters must be > 0"); static_assert(_nr > 0, "The number of rows in a filter must be > 0"); static_assert(_nc > 0, "The number of columns in a filter must be > 0"); static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size."); static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size."); cont_( num_con_outputs o ) : learning_rate_multiplier(1), weight_decay_multiplier(1), bias_learning_rate_multiplier(1), bias_weight_decay_multiplier(0), num_filters_(o.num_outputs), padding_y_(_padding_y), padding_x_(_padding_x), use_bias(true) { DLIB_CASSERT(num_filters_ > 0); } cont_() : cont_(num_con_outputs(_num_filters)) {} long num_filters() const { return num_filters_; } long nr() const { return _nr; } long nc() const { return _nc; } long stride_y() const { return _stride_y; } long stride_x() const { return _stride_x; } long padding_y() const { return padding_y_; } long padding_x() const { return padding_x_; } void set_num_filters(long num) { DLIB_CASSERT(num > 0); if (num != num_filters_) { DLIB_CASSERT(get_layer_params().size() == 0, "You can't change the number of filters in cont_ if the parameter tensor has already been 
allocated."); num_filters_ = num; } } double get_learning_rate_multiplier () const { return learning_rate_multiplier; } double get_weight_decay_multiplier () const { return weight_decay_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } void disable_bias() { use_bias = false; } bool bias_is_disabled() const { return !use_bias; } inline dpoint map_output_to_input ( dpoint p ) const { p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); return p; } inline dpoint map_input_to_output ( dpoint p ) const { p.x() = p.x()*stride_x() - padding_x() + nc()/2; p.y() = p.y()*stride_y() - padding_y() + nr()/2; return p; } cont_ ( const cont_& item ) : params(item.params), filters(item.filters), biases(item.biases), learning_rate_multiplier(item.learning_rate_multiplier), weight_decay_multiplier(item.weight_decay_multiplier), bias_learning_rate_multiplier(item.bias_learning_rate_multiplier), bias_weight_decay_multiplier(item.bias_weight_decay_multiplier), num_filters_(item.num_filters_), padding_y_(item.padding_y_), padding_x_(item.padding_x_), use_bias(item.use_bias) { // this->conv is non-copyable and basically stateless, so we have to write our // own copy to avoid trying to copy it and getting an error. } cont_& operator= ( const cont_& item ) { if (this == &item) return *this; // this->conv is non-copyable and basically stateless, so we have to write our // own copy to avoid trying to copy it and getting an error. 
params = item.params; filters = item.filters; biases = item.biases; padding_y_ = item.padding_y_; padding_x_ = item.padding_x_; learning_rate_multiplier = item.learning_rate_multiplier; weight_decay_multiplier = item.weight_decay_multiplier; bias_learning_rate_multiplier = item.bias_learning_rate_multiplier; bias_weight_decay_multiplier = item.bias_weight_decay_multiplier; num_filters_ = item.num_filters_; use_bias = item.use_bias; return *this; } template void setup (const SUBNET& sub) { long num_inputs = _nr*_nc*sub.get_output().k(); long num_outputs = num_filters_; // allocate params for the filters and also for the filter bias values. params.set_size(num_inputs*num_filters_ + num_filters_ * static_cast(use_bias)); dlib::rand rnd(std::rand()); randomize_parameters(params, num_inputs+num_outputs, rnd); filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc); if (use_bias) { biases = alias_tensor(1,num_filters_); // set the initial bias values to zero biases(params,filters.size()) = 0; } } template void forward(const SUBNET& sub, resizable_tensor& output) { auto filt = filters(params,0); unsigned int gnr = _stride_y * (sub.get_output().nr() - 1) + filt.nr() - 2 * padding_y_; unsigned int gnc = _stride_x * (sub.get_output().nc() - 1) + filt.nc() - 2 * padding_x_; unsigned int gnsamps = sub.get_output().num_samples(); unsigned int gk = filt.k(); output.set_size(gnsamps,gk,gnr,gnc); conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_); conv.get_gradient_for_data(false, sub.get_output(),filt,output); if (use_bias) { tt::add(1,output,1,biases(params,filters.size())); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { auto filt = filters(params,0); conv(true, sub.get_gradient_input(),gradient_input, filt); // no point computing the parameter gradients if they won't be used. 
if (learning_rate_multiplier != 0) { auto filt = filters(params_grad,0); conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt); if (use_bias) { auto b = biases(params_grad, filters.size()); tt::assign_conv_bias_gradient(b, gradient_input); } } } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const cont_& item, std::ostream& out) { serialize("cont_2", out); serialize(item.params, out); serialize(item.num_filters_, out); serialize(_nr, out); serialize(_nc, out); serialize(_stride_y, out); serialize(_stride_x, out); serialize(item.padding_y_, out); serialize(item.padding_x_, out); serialize(item.filters, out); serialize(item.biases, out); serialize(item.learning_rate_multiplier, out); serialize(item.weight_decay_multiplier, out); serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_weight_decay_multiplier, out); serialize(item.use_bias, out); } friend void deserialize(cont_& item, std::istream& in) { std::string version; deserialize(version, in); long nr; long nc; int stride_y; int stride_x; if (version == "cont_1" || version == "cont_2") { deserialize(item.params, in); deserialize(item.num_filters_, in); deserialize(nr, in); deserialize(nc, in); deserialize(stride_y, in); deserialize(stride_x, in); deserialize(item.padding_y_, in); deserialize(item.padding_x_, in); deserialize(item.filters, in); deserialize(item.biases, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.weight_decay_multiplier, in); deserialize(item.bias_learning_rate_multiplier, in); deserialize(item.bias_weight_decay_multiplier, in); if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_"); if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_"); if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_"); if (nc 
!= _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_"); if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_"); if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_"); if (version == "cont_2") { deserialize(item.use_bias, in); } } else { throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_."); } } friend std::ostream& operator<<(std::ostream& out, const cont_& item) { out << "cont\t (" << "num_filters="<\n"; out << mat(item.params); out << "\n"; } private: resizable_tensor params; alias_tensor filters, biases; tt::tensor_conv conv; double learning_rate_multiplier; double weight_decay_multiplier; double bias_learning_rate_multiplier; double bias_weight_decay_multiplier; long num_filters_; int padding_y_; int padding_x_; bool use_bias; }; template < long num_filters, long nr, long nc, int stride_y, int stride_x, typename SUBNET > using cont = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < int scale_y, int scale_x > class upsample_ { public: static_assert(scale_y >= 1, "upsampling scale factor can't be less than 1."); static_assert(scale_x >= 1, "upsampling scale factor can't be less than 1."); upsample_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { output.set_size( sub.get_output().num_samples(), sub.get_output().k(), scale_y*sub.get_output().nr(), scale_x*sub.get_output().nc()); tt::resize_bilinear(output, sub.get_output()); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { tt::resize_bilinear_gradient(sub.get_gradient_input(), gradient_input); } inline dpoint map_input_to_output (dpoint p) const { p.x() = p.x()*scale_x; p.y() = p.y()*scale_y; return p; } inline dpoint map_output_to_input (dpoint 
p) const { p.x() = p.x()/scale_x; p.y() = p.y()/scale_y; return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const upsample_& /*item*/, std::ostream& out) { serialize("upsample_", out); serialize(scale_y, out); serialize(scale_x, out); } friend void deserialize(upsample_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "upsample_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::upsample_."); int _scale_y; int _scale_x; deserialize(_scale_y, in); deserialize(_scale_x, in); if (_scale_y != scale_y || _scale_x != scale_x) throw serialization_error("Wrong scale found while deserializing dlib::upsample_"); } friend std::ostream& operator<<(std::ostream& out, const upsample_& /*item*/) { out << "upsample\t (" << "scale_y="<\n"; } private: resizable_tensor params; }; template < int scale, typename SUBNET > using upsample = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long NR_, long NC_ > class resize_to_ { public: static_assert(NR_ >= 1, "NR resize parameter can't be less than 1."); static_assert(NC_ >= 1, "NC resize parameter can't be less than 1."); resize_to_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { scale_y = (double)NR_/(double)sub.get_output().nr(); scale_x = (double)NC_/(double)sub.get_output().nc(); output.set_size( sub.get_output().num_samples(), sub.get_output().k(), NR_, NC_); tt::resize_bilinear(output, sub.get_output()); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { tt::resize_bilinear_gradient(sub.get_gradient_input(), gradient_input); } inline dpoint map_input_to_output (dpoint p) const { p.x() = p.x()*scale_x; p.y() = p.y()*scale_y; return p; } inline dpoint map_output_to_input 
(dpoint p) const { p.x() = p.x()/scale_x; p.y() = p.y()/scale_y; return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const resize_to_& item, std::ostream& out) { serialize("resize_to_", out); serialize(NR_, out); serialize(NC_, out); serialize(item.scale_y, out); serialize(item.scale_x, out); } friend void deserialize(resize_to_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "resize_to_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_to_."); long _nr; long _nc; deserialize(_nr, in); deserialize(_nc, in); deserialize(item.scale_y, in); deserialize(item.scale_x, in); if (_nr != NR_ || _nc != NC_) throw serialization_error("Wrong size found while deserializing dlib::resize_to_"); } friend std::ostream& operator<<(std::ostream& out, const resize_to_& /*item*/) { out << "resize_to (" << "nr=" << NR_ << ", nc=" << NC_ << ")"; return out; } friend void to_xml(const resize_to_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; double scale_y; double scale_x; }; // end of class resize_to_ template < long NR, long NC, typename SUBNET > using resize_to = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template class reshape_to_ { public: explicit reshape_to_() : output_k(k_), output_nr(nr_), output_nc(nc_) { static_assert(k_ == -1 || k_ > 0, "Output k must be positive or -1"); static_assert(nr_ == -1 || nr_ > 0, "Output nr must be positive or -1"); static_assert(nc_ == -1 || nc_ > 0, "Output nc must be positive or -1"); input_k = input_nr = input_nc = 0; needs_rescale = false; } // Getters for dimensions long get_output_k() const { return output_k; } long get_output_nr() const { return output_nr; } long get_output_nc() const { return output_nc; } // Setters for dimensions void set_output_k(long k) { 
DLIB_CASSERT(k == -1 || k > 0, "Output k must be positive or -1 to keep original dimension"); output_k = k; } void set_output_nr(long nr) { DLIB_CASSERT(nr == -1 || nr > 0, "output nr must be positive or -1 to keep original dimension"); output_nr = nr; } void set_output_nc(long nc) { DLIB_CASSERT(nc == -1 || nc > 0, "output nc must be positive or -1 to keep original dimension"); output_nc = nc; } template void setup(const SUBNET& sub) { const auto& input = sub.get_output(); input_k = input.k(); input_nr = input.nr(); input_nc = input.nc(); // Calculate output dimensions using input dims where target is -1 if (k_ == -1) output_k = input_k; if (nr_ == -1) output_nr = input_nr; if (nc_ == -1) output_nc = input_nc; // Check if this is well a pure reshape long input_elements = input_k * input_nr * input_nc; long output_elements = output_k * output_nr * output_nc; if (input_elements != output_elements && input_k == output_k) needs_rescale = true; DLIB_CASSERT(input_elements == output_elements || needs_rescale, "Cannot reshape tensor of " << input_elements << " elements into shape with " << output_elements << " elements. 
" << "For spatial rescaling, the channel dimension (k) must remain constant."); } template void forward(const SUBNET& sub, resizable_tensor& output) { // Set the output size (always preserving batch dimension) const tensor& input = sub.get_output(); output.set_size(input.num_samples(), output_k, output_nr, output_nc); if (!needs_rescale) { // Create an alias of the input tensor with the output shape alias_tensor input_alias(output.num_samples(), output_k, output_nr, output_nc); // Get a view of the input tensor with the new shape auto input_reshaped = input_alias(const_cast(input), 0); // Copy the view to the output tensor tt::copy_tensor(false, output, 0, input_reshaped, 0, input_reshaped.k()); } else { // Only spatial dimensions need to be resized tt::resize_bilinear(output, input); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto& grad = sub.get_gradient_input(); if (!needs_rescale) { // Create an alias of the gradient tensor with the original input shape alias_tensor grad_alias(grad.num_samples(), grad.k(), grad.nr(), grad.nc()); // Get a view of the input gradient with the required shape auto grad_reshaped = grad_alias(const_cast(gradient_input), 0); // Copy the view to the output gradient tt::copy_tensor(true, grad, 0, grad_reshaped, 0, grad_reshaped.k()); } else { // Only spatial dimensions were resized tt::resize_bilinear_gradient(grad, gradient_input); } } // Mapping functions for coordinate transformations inline dpoint map_input_to_output(const dpoint& p) const { double scale_x = output_nc / static_cast(input_nc); double scale_y = output_nr / static_cast(input_nr); return dpoint(p.x() * scale_x, p.y() * scale_y); } inline dpoint map_output_to_input(const dpoint& p) const { double scale_x = input_nc / static_cast(output_nc); double scale_y = input_nr / static_cast(output_nr); return dpoint(p.x() * scale_x, p.y() * scale_y); } const tensor& get_layer_params() const { return params; } tensor& 
get_layer_params() { return params; } friend void serialize(const reshape_to_& item, std::ostream& out) { serialize("reshape_to_", out); serialize(item.input_k, out); serialize(item.input_nr, out); serialize(item.input_nc, out); serialize(item.output_k, out); serialize(item.output_nr, out); serialize(item.output_nc, out); serialize(item.needs_rescale, out); } friend void deserialize(reshape_to_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "reshape_to_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::reshape_to_."); deserialize(item.input_k, in); deserialize(item.input_nr, in); deserialize(item.input_nc, in); deserialize(item.output_k, in); deserialize(item.output_nr, in); deserialize(item.output_nc, in); deserialize(item.needs_rescale, in); } friend std::ostream& operator<<(std::ostream& out, const reshape_to_& item) { out << "reshape_to ("; out << "k=" << std::to_string(item.output_k); out << ", nr=" << std::to_string(item.output_nr); out << ", nc=" << std::to_string(item.output_nc); out << ", mode=" << (item.needs_rescale ? "spatial_rescale" : "pure_reshape"); out << ")"; return out; } friend void to_xml(const reshape_to_& item, std::ostream& out) { out << "\n"; } private: long input_k, input_nr, input_nc; // Input dimensions long output_k, output_nr, output_nc; // Output dimensions bool needs_rescale; resizable_tensor params; // No trainable parameters }; template using reshape_to = add_layer, SUBNET>; template using flatten = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 
0 : _nc/2 > class max_pool_ { static_assert(_nr >= 0, "The number of rows in a filter must be >= 0"); static_assert(_nc >= 0, "The number of columns in a filter must be >= 0"); static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)), "The padding must be smaller than the filter size, unless the filters size is 0."); static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)), "The padding must be smaller than the filter size, unless the filters size is 0."); public: max_pool_( ) : padding_y_(_padding_y), padding_x_(_padding_x) {} long nr() const { return _nr; } long nc() const { return _nc; } long stride_y() const { return _stride_y; } long stride_x() const { return _stride_x; } long padding_y() const { return padding_y_; } long padding_x() const { return padding_x_; } inline dpoint map_input_to_output ( dpoint p ) const { p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); return p; } inline dpoint map_output_to_input ( dpoint p ) const { p.x() = p.x()*stride_x() - padding_x() + nc()/2; p.y() = p.y()*stride_y() - padding_y() + nr()/2; return p; } max_pool_ ( const max_pool_& item ) : padding_y_(item.padding_y_), padding_x_(item.padding_x_) { // this->mp is non-copyable so we have to write our own copy to avoid trying to // copy it and getting an error. } max_pool_& operator= ( const max_pool_& item ) { if (this == &item) return *this; padding_y_ = item.padding_y_; padding_x_ = item.padding_x_; // this->mp is non-copyable so we have to write our own copy to avoid trying to // copy it and getting an error. 
return *this; } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(), _nc!=0?_nc:sub.get_output().nc(), _stride_y, _stride_x, padding_y_, padding_x_); mp(output, sub.get_output()); } template void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(), _nc!=0?_nc:sub.get_output().nc(), _stride_y, _stride_x, padding_y_, padding_x_); mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input()); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const max_pool_& item, std::ostream& out) { serialize("max_pool_2", out); serialize(_nr, out); serialize(_nc, out); serialize(_stride_y, out); serialize(_stride_x, out); serialize(item.padding_y_, out); serialize(item.padding_x_, out); } friend void deserialize(max_pool_& item, std::istream& in) { std::string version; deserialize(version, in); long nr; long nc; int stride_y; int stride_x; if (version == "max_pool_2") { deserialize(nr, in); deserialize(nc, in); deserialize(stride_y, in); deserialize(stride_x, in); deserialize(item.padding_y_, in); deserialize(item.padding_x_, in); } else { throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_."); } if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_"); if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_"); if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_"); if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_"); if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found 
while deserializing dlib::max_pool_"); if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_"); } friend std::ostream& operator<<(std::ostream& out, const max_pool_& item) { out << "max_pool (" << "nr="<<_nr << ", nc="<<_nc << ", stride_y="<<_stride_y << ", stride_x="<<_stride_x << ", padding_y="<\n"; } private: tt::pooling mp; resizable_tensor params; int padding_y_; int padding_x_; }; template < long nr, long nc, int stride_y, int stride_x, typename SUBNET > using max_pool = add_layer, SUBNET>; template < typename SUBNET > using max_pool_everything = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class avg_pool_ { public: static_assert(_nr >= 0, "The number of rows in a filter must be >= 0"); static_assert(_nc >= 0, "The number of columns in a filter must be >= 0"); static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)), "The padding must be smaller than the filter size, unless the filters size is 0."); static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)), "The padding must be smaller than the filter size, unless the filters size is 0."); avg_pool_( ) : padding_y_(_padding_y), padding_x_(_padding_x) {} long nr() const { return _nr; } long nc() const { return _nc; } long stride_y() const { return _stride_y; } long stride_x() const { return _stride_x; } long padding_y() const { return padding_y_; } long padding_x() const { return padding_x_; } inline dpoint map_input_to_output ( dpoint p ) const { p.x() = (p.x()+padding_x()-nc()/2)/stride_x(); p.y() = (p.y()+padding_y()-nr()/2)/stride_y(); return 
p; } inline dpoint map_output_to_input ( dpoint p ) const { p.x() = p.x()*stride_x() - padding_x() + nc()/2; p.y() = p.y()*stride_y() - padding_y() + nr()/2; return p; } avg_pool_ ( const avg_pool_& item ) : padding_y_(item.padding_y_), padding_x_(item.padding_x_) { // this->ap is non-copyable so we have to write our own copy to avoid trying to // copy it and getting an error. } avg_pool_& operator= ( const avg_pool_& item ) { if (this == &item) return *this; padding_y_ = item.padding_y_; padding_x_ = item.padding_x_; // this->ap is non-copyable so we have to write our own copy to avoid trying to // copy it and getting an error. return *this; } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(), _nc!=0?_nc:sub.get_output().nc(), _stride_y, _stride_x, padding_y_, padding_x_); ap(output, sub.get_output()); } template void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(), _nc!=0?_nc:sub.get_output().nc(), _stride_y, _stride_x, padding_y_, padding_x_); ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input()); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const avg_pool_& item, std::ostream& out) { serialize("avg_pool_2", out); serialize(_nr, out); serialize(_nc, out); serialize(_stride_y, out); serialize(_stride_x, out); serialize(item.padding_y_, out); serialize(item.padding_x_, out); } friend void deserialize(avg_pool_& item, std::istream& in) { std::string version; deserialize(version, in); long nr; long nc; int stride_y; int stride_x; if (version == "avg_pool_2") { deserialize(nr, in); deserialize(nc, in); deserialize(stride_y, in); deserialize(stride_x, in); deserialize(item.padding_y_, in); 
deserialize(item.padding_x_, in); } else { throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_."); } if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_"); if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_"); if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_"); if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_"); if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_"); if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_"); } friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item) { out << "avg_pool (" << "nr="<<_nr << ", nc="<<_nc << ", stride_y="<<_stride_y << ", stride_x="<<_stride_x << ", padding_y="<\n"; } private: tt::pooling ap; resizable_tensor params; int padding_y_; int padding_x_; }; template < long nr, long nc, int stride_y, int stride_x, typename SUBNET > using avg_pool = add_layer, SUBNET>; template < typename SUBNET > using avg_pool_everything = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- const double DEFAULT_LAYER_NORM_EPS = 1e-5; class layer_norm_ { public: explicit layer_norm_( double eps_ = DEFAULT_LAYER_NORM_EPS ) : learning_rate_multiplier(1), weight_decay_multiplier(0), bias_learning_rate_multiplier(1), bias_weight_decay_multiplier(1), eps(eps_) { } double get_eps() const { return eps; } double get_learning_rate_multiplier () const { return learning_rate_multiplier; } double get_weight_decay_multiplier () const { return weight_decay_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_weight_decay_multiplier(double val) { 
weight_decay_multiplier = val; } double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } template void setup (const SUBNET& sub) { gamma = alias_tensor(1, sub.get_output().k()); beta = gamma; params.set_size(gamma.size()+beta.size()); gamma(params,0) = 1; beta(params,gamma.size()) = 0; } template void forward(const SUBNET& sub, resizable_tensor& output) { auto g = gamma(params,0); auto b = beta(params,gamma.size()); tt::layer_normalize(eps, output, means, invstds, sub.get_output(), g, b); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { auto g = gamma(params, 0); auto g_grad = gamma(params_grad, 0); auto b_grad = beta(params_grad, gamma.size()); tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad, dmeans, dvars); } const tensor& get_layer_params() const { return params; }; tensor& get_layer_params() { return params; }; friend void serialize(const layer_norm_& item, std::ostream& out) { serialize("layer_norm_", out); serialize(item.params, out); serialize(item.gamma, out); serialize(item.beta, out); serialize(item.means, out); serialize(item.invstds, out); serialize(item.learning_rate_multiplier, out); serialize(item.weight_decay_multiplier, out); serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_weight_decay_multiplier, out); serialize(item.eps, out); } friend void deserialize(layer_norm_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != 
"layer_norm_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::layer_norm_."); deserialize(item.params, in); deserialize(item.gamma, in); deserialize(item.beta, in); deserialize(item.means, in); deserialize(item.invstds, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.weight_decay_multiplier, in); deserialize(item.bias_learning_rate_multiplier, in); deserialize(item.bias_weight_decay_multiplier, in); deserialize(item.eps, in); } friend std::ostream& operator<<(std::ostream& out, const layer_norm_& item) { out << "layer_norm"; out << " eps="<\n"; out << mat(item.params); out << "\n"; } private: resizable_tensor params; alias_tensor gamma, beta; resizable_tensor means, invstds; resizable_tensor dmeans, dvars; double learning_rate_multiplier; double weight_decay_multiplier; double bias_learning_rate_multiplier; double bias_weight_decay_multiplier; double eps; }; template using layer_norm = add_layer; // ---------------------------------------------------------------------------------------- const double DEFAULT_RMS_NORM_EPS = 1e-5; class rms_norm_ { public: explicit rms_norm_( double eps_ = DEFAULT_RMS_NORM_EPS ) : learning_rate_multiplier(1), weight_decay_multiplier(0), bias_learning_rate_multiplier(1), bias_weight_decay_multiplier(1), eps(eps_) { } double get_eps() const { return eps; } double get_learning_rate_multiplier() const { return learning_rate_multiplier; } double get_weight_decay_multiplier() const { return weight_decay_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } double get_bias_learning_rate_multiplier() const { return bias_learning_rate_multiplier; } double get_bias_weight_decay_multiplier() const { return bias_weight_decay_multiplier; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void 
set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } inline dpoint map_input_to_output(const dpoint& p) const { return p; } inline dpoint map_output_to_input(const dpoint& p) const { return p; } template void setup(const SUBNET& sub) { gamma = alias_tensor(1, sub.get_output().k()); params.set_size(gamma.size()); gamma(params, 0) = 1; } template void forward(const SUBNET& sub, resizable_tensor& output) { auto g = gamma(params, 0); tt::rms_normalize(eps, output, scale, sub.get_output(), g); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { auto g = gamma(params, 0); auto g_grad = gamma(params_grad, 0); tt::rms_normalize_gradient(gradient_input, scale, sub.get_output(), g, sub.get_gradient_input(), g_grad, dscale); } const tensor& get_layer_params() const { return params; }; tensor& get_layer_params() { return params; }; friend void serialize(const rms_norm_& item, std::ostream& out) { serialize("rms_norm_", out); serialize(item.params, out); serialize(item.gamma, out); serialize(item.learning_rate_multiplier, out); serialize(item.weight_decay_multiplier, out); serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_weight_decay_multiplier, out); serialize(item.eps, out); } friend void deserialize(rms_norm_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "rms_norm_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::rms_norm_."); deserialize(item.params, in); deserialize(item.gamma, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.weight_decay_multiplier, in); deserialize(item.bias_learning_rate_multiplier, in); deserialize(item.bias_weight_decay_multiplier, in); deserialize(item.eps, in); } friend std::ostream& operator<<(std::ostream& out, const rms_norm_& item) { out << "rms_norm"; out << " (eps=" << item.eps << ")"; out << " learning_rate_mult=" << 
item.learning_rate_multiplier; out << " weight_decay_mult=" << item.weight_decay_multiplier; out << " bias_learning_rate_mult=" << item.bias_learning_rate_multiplier; out << " bias_weight_decay_mult=" << item.bias_weight_decay_multiplier; return out; } friend void to_xml(const rms_norm_& item, std::ostream& out) { out << "\n"; out << mat(item.params); out << "\n"; } private: resizable_tensor params; alias_tensor gamma; resizable_tensor scale; resizable_tensor dscale; double learning_rate_multiplier; double weight_decay_multiplier; double bias_learning_rate_multiplier; double bias_weight_decay_multiplier; double eps; }; template using rms_norm = add_layer; // ---------------------------------------------------------------------------------------- enum layer_mode { CONV_MODE = 0, FC_MODE = 1 }; const double DEFAULT_BATCH_NORM_EPS = 0.0001; template < layer_mode mode > class bn_ { public: explicit bn_( unsigned long window_size, double eps_ = DEFAULT_BATCH_NORM_EPS ) : num_updates(0), running_stats_window_size(window_size), learning_rate_multiplier(1), weight_decay_multiplier(0), bias_learning_rate_multiplier(1), bias_weight_decay_multiplier(1), eps(eps_) { DLIB_CASSERT(window_size > 0, "The batch normalization running stats window size can't be 0."); } bn_() : bn_(100) {} layer_mode get_mode() const { return mode; } unsigned long get_running_stats_window_size () const { return running_stats_window_size; } void set_running_stats_window_size (unsigned long new_window_size ) { DLIB_CASSERT(new_window_size > 0, "The batch normalization running stats window size can't be 0."); running_stats_window_size = new_window_size; } double get_eps() const { return eps; } double get_learning_rate_multiplier () const { return learning_rate_multiplier; } double get_weight_decay_multiplier () const { return weight_decay_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_weight_decay_multiplier(double val) { weight_decay_multiplier = 
val; } double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } template void setup (const SUBNET& sub) { if (mode == FC_MODE) { gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc()); } else { gamma = alias_tensor(1, sub.get_output().k()); } beta = gamma; params.set_size(gamma.size()+beta.size()); gamma(params,0) = 1; beta(params,gamma.size()) = 0; running_means.copy_size(gamma(params,0)); running_variances.copy_size(gamma(params,0)); running_means = 0; running_variances = 1; num_updates = 0; } template void forward(const SUBNET& sub, resizable_tensor& output) { auto g = gamma(params,0); auto b = beta(params,gamma.size()); if (sub.get_output().num_samples() > 1) { const double decay = 1.0 - num_updates/(num_updates+1.0); ++num_updates; if (num_updates > running_stats_window_size) num_updates = running_stats_window_size; if (mode == FC_MODE) tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b); else tt::batch_normalize_conv(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b); } else // we are running in testing mode so we just linearly scale the input tensor. 
{ if (mode == FC_MODE) tt::batch_normalize_inference(eps, output, sub.get_output(), g, b, running_means, running_variances); else tt::batch_normalize_conv_inference(eps, output, sub.get_output(), g, b, running_means, running_variances); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { auto g = gamma(params,0); auto g_grad = gamma(params_grad, 0); auto b_grad = beta(params_grad, gamma.size()); if (mode == FC_MODE) tt::batch_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad ); else tt::batch_normalize_conv_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad ); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const bn_& item, std::ostream& out) { if (mode == CONV_MODE) serialize("bn_con2", out); else // if FC_MODE serialize("bn_fc2", out); serialize(item.params, out); serialize(item.gamma, out); serialize(item.beta, out); serialize(item.means, out); serialize(item.invstds, out); serialize(item.running_means, out); serialize(item.running_variances, out); serialize(item.num_updates, out); serialize(item.running_stats_window_size, out); serialize(item.learning_rate_multiplier, out); serialize(item.weight_decay_multiplier, out); serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_weight_decay_multiplier, out); serialize(item.eps, out); } friend void deserialize(bn_& item, std::istream& in) { std::string version; deserialize(version, in); if (mode == CONV_MODE) { if (version != "bn_con2") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_."); } else // must be in FC_MODE { if (version != "bn_fc2") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_."); } deserialize(item.params, in); deserialize(item.gamma, in); 
deserialize(item.beta, in); deserialize(item.means, in); deserialize(item.invstds, in); deserialize(item.running_means, in); deserialize(item.running_variances, in); deserialize(item.num_updates, in); deserialize(item.running_stats_window_size, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.weight_decay_multiplier, in); deserialize(item.bias_learning_rate_multiplier, in); deserialize(item.bias_weight_decay_multiplier, in); deserialize(item.eps, in); } friend std::ostream& operator<<(std::ostream& out, const bn_& item) { if (mode == CONV_MODE) out << "bn_con "; else out << "bn_fc "; out << " eps="<\n"; out << mat(item.params); if (mode==CONV_MODE) out << "\n"; else out << "\n"; } private: friend class affine_; resizable_tensor params; alias_tensor gamma, beta; resizable_tensor means, running_means; resizable_tensor invstds, running_variances; unsigned long num_updates; unsigned long running_stats_window_size; double learning_rate_multiplier; double weight_decay_multiplier; double bias_learning_rate_multiplier; double bias_weight_decay_multiplier; double eps; }; template using bn_con = add_layer, SUBNET>; template using bn_fc = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- enum fc_bias_mode { FC_HAS_BIAS = 0, FC_NO_BIAS = 1 }; struct num_fc_outputs { num_fc_outputs(unsigned long n) : num_outputs(n) {} unsigned long num_outputs; }; template < unsigned long num_outputs_, fc_bias_mode bias_mode > class fc_ { static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0"); public: fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0), learning_rate_multiplier(1), weight_decay_multiplier(1), bias_learning_rate_multiplier(1), bias_weight_decay_multiplier(0), use_bias(true) {} fc_() : fc_(num_fc_outputs(num_outputs_)) {} double get_learning_rate_multiplier () const { return learning_rate_multiplier; } double get_weight_decay_multiplier () const { return 
weight_decay_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } void disable_bias() { use_bias = false; } bool bias_is_disabled() const { return !use_bias; } unsigned long get_num_outputs ( ) const { return num_outputs; } void set_num_outputs(long num) { DLIB_CASSERT(num > 0); if (num != (long)num_outputs) { DLIB_CASSERT(get_layer_params().size() == 0, "You can't change the number of filters in fc_ if the parameter tensor has already been allocated."); num_outputs = num; } } fc_bias_mode get_bias_mode ( ) const { return bias_mode; } template void setup (const SUBNET& sub) { num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(); if (bias_mode == FC_HAS_BIAS && use_bias) params.set_size(num_inputs+1, num_outputs); else params.set_size(num_inputs, num_outputs); dlib::rand rnd(std::rand()); randomize_parameters(params, num_inputs+num_outputs, rnd); weights = alias_tensor(num_inputs, num_outputs); if (bias_mode == FC_HAS_BIAS && use_bias) { biases = alias_tensor(1,num_outputs); // set the initial bias values to zero biases(params,weights.size()) = 0; } } template void forward(const SUBNET& sub, resizable_tensor& output) { DLIB_CASSERT((long)num_inputs == sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(), "The size of the input tensor to this fc layer doesn't match the size the fc layer was trained with."); output.set_size(sub.get_output().num_samples(), num_outputs); auto w = weights(params, 0); tt::gemm(0,output, 1,sub.get_output(),false, w,false); if 
(bias_mode == FC_HAS_BIAS && use_bias) { auto b = biases(params, weights.size()); tt::add(1,output,1,b); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { // no point computing the parameter gradients if they won't be used. if (learning_rate_multiplier != 0) { // compute the gradient of the weight parameters. auto pw = weights(params_grad, 0); tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false); if (bias_mode == FC_HAS_BIAS && use_bias) { // compute the gradient of the bias parameters. auto pb = biases(params_grad, weights.size()); tt::assign_bias_gradient(pb, gradient_input); } } // compute the gradient for the data auto w = weights(params, 0); tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true); } alias_tensor_instance get_weights() { return weights(params, 0); } alias_tensor_const_instance get_weights() const { return weights(params, 0); } alias_tensor_instance get_biases() { static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector " "to be retrieved, as per template parameter 'bias_mode'."); return biases(params, weights.size()); } alias_tensor_const_instance get_biases() const { static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector " "to be retrieved, as per template parameter 'bias_mode'."); return biases(params, weights.size()); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const fc_& item, std::ostream& out) { serialize("fc_3", out); serialize(item.num_outputs, out); serialize(item.num_inputs, out); serialize(item.params, out); serialize(item.weights, out); serialize(item.biases, out); serialize((int)bias_mode, out); serialize(item.learning_rate_multiplier, out); serialize(item.weight_decay_multiplier, out); serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_weight_decay_multiplier, out); serialize(item.use_bias, out); } friend void 
deserialize(fc_& item, std::istream& in) { std::string version; deserialize(version, in); if (version == "fc_2" || version == "fc_3") { deserialize(item.num_outputs, in); deserialize(item.num_inputs, in); deserialize(item.params, in); deserialize(item.weights, in); deserialize(item.biases, in); int bmode = 0; deserialize(bmode, in); if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_"); deserialize(item.learning_rate_multiplier, in); deserialize(item.weight_decay_multiplier, in); deserialize(item.bias_learning_rate_multiplier, in); deserialize(item.bias_weight_decay_multiplier, in); if (version == "fc_3") { deserialize(item.use_bias, in); } } else { throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_."); } } friend std::ostream& operator<<(std::ostream& out, const fc_& item) { if (bias_mode == FC_HAS_BIAS) { out << "fc\t (" << "num_outputs="<\n"; out << mat(item.params); out << "\n"; } else { out << "\n"; out << mat(item.params); out << "\n"; } } private: unsigned long num_outputs; unsigned long num_inputs; resizable_tensor params; alias_tensor weights, biases; double learning_rate_multiplier; double weight_decay_multiplier; double bias_learning_rate_multiplier; double bias_weight_decay_multiplier; bool use_bias; }; template < unsigned long num_outputs, typename SUBNET > using fc = add_layer, SUBNET>; template < unsigned long num_outputs, typename SUBNET > using fc_no_bias = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- enum linear_bias_mode { LINEAR_HAS_BIAS = 0, LINEAR_NO_BIAS = 1 }; template < unsigned long num_outputs_, linear_bias_mode bias_mode_ = LINEAR_HAS_BIAS > class linear_ { static_assert(num_outputs_ > 0, "The number of outputs from a linear_ layer must be > 0"); public: explicit linear_() : num_outputs(num_outputs_), num_inputs(0), learning_rate_multiplier(1), 
bias_mode(bias_mode_) { } linear_(const linear_& other) : num_outputs(other.num_outputs), num_inputs(other.num_inputs), learning_rate_multiplier(other.learning_rate_multiplier), bias_mode(other.bias_mode), params(other.params), weights(other.weights), biases(other.biases) { } linear_& operator=(const linear_& other) { if (this != &other) { num_outputs = other.num_outputs; num_inputs = other.num_inputs; learning_rate_multiplier = other.learning_rate_multiplier; bias_mode = other.bias_mode; params = other.params; weights = other.weights; biases = other.biases; } return *this; } double get_learning_rate_multiplier() const { return learning_rate_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } unsigned long get_num_outputs() const { return num_outputs; } void set_num_outputs(long num) { DLIB_CASSERT(num > 0, "The number of outputs must be > 0, but num == " << num); if (num != (long)num_outputs) { DLIB_CASSERT(get_layer_params().size() == 0, "You can't change the number of filters in linear_ if the parameter tensor has already been allocated."); num_outputs = num; } } unsigned long get_num_inputs() const { return num_inputs; } linear_bias_mode get_bias_mode() const { return bias_mode; } template void setup(const SUBNET& sub) { num_inputs = sub.get_output().nc(); if (bias_mode == LINEAR_HAS_BIAS) params.set_size(num_inputs + 1, num_outputs); else params.set_size(num_inputs, num_outputs); dlib::rand rnd(std::rand()); randomize_parameters(params, num_inputs + num_outputs, rnd); weights = alias_tensor(num_inputs, num_outputs); if (bias_mode == LINEAR_HAS_BIAS) { biases = alias_tensor(1, num_outputs); biases(params, weights.size()) = 0; } } template void forward(const SUBNET& sub, resizable_tensor& output) { const auto& prev_output = sub.get_output(); DLIB_CASSERT((long)num_inputs == prev_output.nc(), "The size of the input tensor to this linear layer doesn't match the size the linear layer was trained with."); 
output.set_size(prev_output.num_samples(), prev_output.k(), prev_output.nr(), num_outputs); auto o = alias_tensor(output.num_samples() * output.k() * output.nr(), num_outputs)(output, 0); auto so = alias_tensor(prev_output.num_samples() * prev_output.k() * prev_output.nr(), num_inputs)(prev_output, 0); auto w = weights(params, 0); tt::gemm(0, (tensor&)o, 1, so, false, w, false); if (bias_mode == LINEAR_HAS_BIAS) { auto b = biases(params, weights.size()); tt::add(1, (tensor&)o, 1, b); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { auto gi = alias_tensor(gradient_input.num_samples() * gradient_input.k() * gradient_input.nr(), num_outputs)(gradient_input, 0); if (learning_rate_multiplier != 0) { const auto& prev_output = sub.get_output(); auto pw = weights(params_grad, 0); auto so = alias_tensor(prev_output.num_samples() * prev_output.k() * prev_output.nr(), num_inputs)(prev_output, 0); tt::gemm(0, pw, learning_rate_multiplier, so, true, gi, false); if (bias_mode == LINEAR_HAS_BIAS) { auto pb = biases(params_grad, weights.size()); tt::assign_bias_gradient(pb, gi); } } //prev_gradient is not const, so that sgi isn't const //since sgi is used as a destination for tt::gemm auto& prev_gradient = sub.get_gradient_input(); alias_tensor_instance sgi = alias_tensor(prev_gradient.num_samples() * prev_gradient.k() * prev_gradient.nr(), num_inputs)(prev_gradient, 0); auto w = weights(params, 0); tt::gemm(1, sgi, 1, gi, false, w, true); } alias_tensor_instance get_weights() { return weights(params, 0); } alias_tensor_const_instance get_weights() const { return weights(params, 0); } alias_tensor_instance get_biases() { static_assert(bias_mode == LINEAR_HAS_BIAS, "This linear_ layer doesn't have a bias vector " "to be retrieved, as per template parameter 'bias_mode'."); return biases(params, weights.size()); } alias_tensor_const_instance get_biases() const { static_assert(bias_mode == LINEAR_HAS_BIAS, "This linear_ layer doesn't have a 
bias vector " "to be retrieved, as per template parameter 'bias_mode'."); return biases(params, weights.size()); } inline dpoint map_input_to_output(const dpoint& p) const { return p; } inline dpoint map_output_to_input(const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const linear_& item, std::ostream& out) { serialize("linear_", out); serialize(item.num_outputs, out); serialize(item.num_inputs, out); serialize(item.params, out); serialize(item.weights, out); serialize(item.biases, out); serialize((int)item.bias_mode, out); serialize(item.learning_rate_multiplier, out); } friend void deserialize(linear_& item, std::istream& in) { std::string version; deserialize(version, in); if (version == "linear_") { deserialize(item.num_outputs, in); deserialize(item.num_inputs, in); deserialize(item.params, in); deserialize(item.weights, in); deserialize(item.biases, in); int bmode; deserialize(bmode, in); item.bias_mode = static_cast(bmode); if (bias_mode_ != item.bias_mode) throw serialization_error("Wrong bias_mode found while deserializing dlib::linear_"); deserialize(item.learning_rate_multiplier, in); } else { throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::linear_."); } } friend std::ostream& operator<<(std::ostream& out, const linear_& item) { out << "linear\t (num_outputs=" << item.num_outputs; if (item.bias_mode == LINEAR_HAS_BIAS) out << ", bias=true"; else out << ", bias=false"; out << ")"; out << " learning_rate_mult=" << item.learning_rate_multiplier; return out; } friend void to_xml(const linear_& item, std::ostream& out) { out << "\n"; out << mat(item.params); out << "\n"; } private: unsigned long num_outputs; unsigned long num_inputs; double learning_rate_multiplier; linear_bias_mode bias_mode; resizable_tensor params; alias_tensor weights, biases; }; template < unsigned long num_outputs, typename 
SUBNET > using linear = add_layer, SUBNET>; template < unsigned long num_outputs, typename SUBNET > using linear_no_bias = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class dropout_ { public: explicit dropout_( float drop_rate_ = 0.5 ) : drop_rate(drop_rate_), rnd(std::rand()) { DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1); } // We have to add a copy constructor and assignment operator because the rnd object // is non-copyable. dropout_( const dropout_& item ) : drop_rate(item.drop_rate), mask(item.mask), rnd(std::rand()) {} dropout_& operator= ( const dropout_& item ) { if (this == &item) return *this; drop_rate = item.drop_rate; mask = item.mask; return *this; } float get_drop_rate ( ) const { return drop_rate; } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { // create a random mask and use it to filter the data mask.copy_size(input); rnd.fill_uniform(mask); tt::threshold(mask, drop_rate); tt::multiply(false, output, input, mask); } void backward_inplace( const tensor& gradient_input, tensor& data_grad, tensor& /*params_grad*/ ) { if (is_same_object(gradient_input, data_grad)) tt::multiply(false, data_grad, mask, gradient_input); else tt::multiply(true, data_grad, mask, gradient_input); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const dropout_& item, std::ostream& out) { serialize("dropout_", out); serialize(item.drop_rate, out); serialize(item.mask, out); } friend void deserialize(dropout_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "dropout_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_."); 
deserialize(item.drop_rate, in); deserialize(item.mask, in); } void clean( ) { mask.clear(); } friend std::ostream& operator<<(std::ostream& out, const dropout_& item) { out << "dropout\t (" << "drop_rate="<\n"; } private: float drop_rate; resizable_tensor mask; tt::tensor_rand rnd; resizable_tensor params; // unused }; template using dropout = add_layer; // ---------------------------------------------------------------------------------------- template class dropout_rate_ : public dropout_ { public: explicit dropout_rate_() : dropout_(static_cast(DROP_RATE_PERCENT) / 100.0f) { static_assert(DROP_RATE_PERCENT >= 0 && DROP_RATE_PERCENT <= 100, "DROP_RATE_PERCENT must be between 0 and 100, inclusive."); } }; template using dropout_rate = add_layer, SUBNET>; template using dropout_10 = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class multiply_ { public: explicit multiply_( float val_ = 0.5 ) : val(val_) { } multiply_ ( const dropout_& item ) : val(1-item.get_drop_rate()) {} float get_multiply_value ( ) const { return val; } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { tt::affine_transform(output, input, val); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } void backward_inplace( const tensor& gradient_input, tensor& data_grad, tensor& /*params_grad*/ ) { if (is_same_object(gradient_input, data_grad)) tt::affine_transform(data_grad, gradient_input, val); else tt::affine_transform(data_grad, data_grad, gradient_input, 1, val); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const multiply_& item, std::ostream& out) { serialize("multiply_", out); serialize(item.val, out); } friend void deserialize(multiply_& item, std::istream& in) { std::string version; 
deserialize(version, in); if (version == "dropout_") { // Since we can build a multiply_ from a dropout_ we check if that's what // is in the stream and if so then just convert it right here. unserialize sin(version, in); dropout_ temp; deserialize(temp, sin); item = temp; return; } if (version != "multiply_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_."); deserialize(item.val, in); } friend std::ostream& operator<<(std::ostream& out, const multiply_& item) { out << "multiply (" << "val="<\n"; } private: float val; resizable_tensor params; // unused }; template using multiply = add_layer; // ---------------------------------------------------------------------------------------- class affine_ { public: affine_( ) : mode(FC_MODE) { } affine_( layer_mode mode_ ) : mode(mode_) { } template < layer_mode bnmode > affine_( const bn_& item ) { gamma = item.gamma; beta = item.beta; mode = bnmode; params.copy_size(item.params); auto g = gamma(params,0); auto b = beta(params,gamma.size()); resizable_tensor temp(item.params); auto sg = gamma(temp,0); auto sb = beta(temp,gamma.size()); g = pointwise_divide(mat(sg), sqrt(mat(item.running_variances)+item.get_eps())); b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means)); } layer_mode get_mode() const { return mode; } void disable() { params.clear(); disabled = true; } bool is_disabled() const { return disabled; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } template void setup (const SUBNET& sub) { if (disabled) return; if (mode == FC_MODE) { gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc()); } else { gamma = alias_tensor(1, sub.get_output().k()); } beta = gamma; params.set_size(gamma.size()+beta.size()); gamma(params,0) = 1; beta(params,gamma.size()) = 0; } void forward_inplace(const tensor& input, tensor& output) { if 
(disabled) return; auto g = gamma(params,0); auto b = beta(params,gamma.size()); if (mode == FC_MODE) tt::affine_transform(output, input, g, b); else tt::affine_transform_conv(output, input, g, b); } void backward_inplace( const tensor& gradient_input, tensor& data_grad, tensor& /*params_grad*/ ) { if (disabled) return; auto g = gamma(params,0); auto b = beta(params,gamma.size()); // We are computing the gradient of dot(gradient_input, computed_output*g + b) if (mode == FC_MODE) { if (is_same_object(gradient_input, data_grad)) tt::multiply(false, data_grad, gradient_input, g); else tt::multiply(true, data_grad, gradient_input, g); } else { if (is_same_object(gradient_input, data_grad)) tt::multiply_conv(false, data_grad, gradient_input, g); else tt::multiply_conv(true, data_grad, gradient_input, g); } } alias_tensor_instance get_gamma() { return gamma(params, 0); }; alias_tensor_const_instance get_gamma() const { return gamma(params, 0); }; alias_tensor_instance get_beta() { return beta(params, gamma.size()); }; alias_tensor_const_instance get_beta() const { return beta(params, gamma.size()); }; const tensor& get_layer_params() const { return empty_params; } tensor& get_layer_params() { return empty_params; } friend void serialize(const affine_& item, std::ostream& out) { serialize("affine_2", out); serialize(item.params, out); serialize(item.gamma, out); serialize(item.beta, out); serialize((int)item.mode, out); serialize(item.disabled, out); } friend void deserialize(affine_& item, std::istream& in) { std::string version; deserialize(version, in); if (version == "bn_con2") { // Since we can build an affine_ from a bn_ we check if that's what is in // the stream and if so then just convert it right here. unserialize sin(version, in); bn_ temp; deserialize(temp, sin); item = temp; return; } else if (version == "bn_fc2") { // Since we can build an affine_ from a bn_ we check if that's what is in // the stream and if so then just convert it right here. 
unserialize sin(version, in); bn_ temp; deserialize(temp, sin); item = temp; return; } if (version != "affine_" && version != "affine_2") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_."); deserialize(item.params, in); deserialize(item.gamma, in); deserialize(item.beta, in); int mode; deserialize(mode, in); item.mode = (layer_mode)mode; if (version == "affine_2") deserialize(item.disabled, in); } friend std::ostream& operator<<(std::ostream& out, const affine_& item) { out << "affine"; if (item.disabled) out << "\t (disabled)"; return out; } friend void to_xml(const affine_& item, std::ostream& out) { if (item.mode==CONV_MODE) out << "\n"; out << mat(item.params); if (item.mode==CONV_MODE) out << "\n"; else out << "\n"; } private: resizable_tensor params, empty_params; alias_tensor gamma, beta; layer_mode mode; bool disabled = false; }; template using affine = add_layer; // ---------------------------------------------------------------------------------------- template < template class tag > class add_prev_ { public: const static unsigned long id = tag_id::id; add_prev_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { auto&& t1 = sub.get_output(); auto&& t2 = layer(sub).get_output(); output.set_size(std::max(t1.num_samples(),t2.num_samples()), std::max(t1.k(),t2.k()), std::max(t1.nr(),t2.nr()), std::max(t1.nc(),t2.nc())); tt::add(output, t1, t2); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { // The gradient just flows backwards to the two layers that forward() added // together. 
tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input); tt::add(layer(sub).get_gradient_input(), layer(sub).get_gradient_input(), gradient_input); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } friend void serialize(const add_prev_& /*item*/, std::ostream& out) { serialize("add_prev_", out); } friend void deserialize(add_prev_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "add_prev_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_."); } friend std::ostream& operator<<(std::ostream& out, const add_prev_& /*item*/) { out << "add_prev"<\n"; } private: resizable_tensor params; }; template < template class tag, typename SUBNET > using add_prev = add_layer, SUBNET>; template using add_prev1 = add_prev; template using add_prev2 = add_prev; template using add_prev3 = add_prev; template using add_prev4 = add_prev; template using add_prev5 = add_prev; template using add_prev6 = add_prev; template using add_prev7 = add_prev; template using add_prev8 = add_prev; template using add_prev9 = add_prev; template using add_prev10 = add_prev; using add_prev1_ = add_prev_; using add_prev2_ = add_prev_; using add_prev3_ = add_prev_; using add_prev4_ = add_prev_; using add_prev5_ = add_prev_; using add_prev6_ = add_prev_; using add_prev7_ = add_prev_; using add_prev8_ = add_prev_; using add_prev9_ = add_prev_; using add_prev10_ = add_prev_; // ---------------------------------------------------------------------------------------- template < template class tag > class mult_prev_ { public: const static unsigned long id = tag_id::id; mult_prev_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { 
auto&& t1 = sub.get_output(); auto&& t2 = layer(sub).get_output(); output.set_size(std::max(t1.num_samples(),t2.num_samples()), std::max(t1.k(),t2.k()), std::max(t1.nr(),t2.nr()), std::max(t1.nc(),t2.nc())); tt::multiply_zero_padded(false, output, t1, t2); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto&& t1 = sub.get_output(); auto&& t2 = layer(sub).get_output(); // The gradient just flows backwards to the two layers that forward() // multiplied together. tt::multiply_zero_padded(true, sub.get_gradient_input(), t2, gradient_input); tt::multiply_zero_padded(true, layer(sub).get_gradient_input(), t1, gradient_input); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } friend void serialize(const mult_prev_& /*item*/, std::ostream& out) { serialize("mult_prev_", out); } friend void deserialize(mult_prev_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "mult_prev_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::mult_prev_."); } friend std::ostream& operator<<(std::ostream& out, const mult_prev_& /*item*/) { out << "mult_prev"<\n"; } private: resizable_tensor params; }; template < template class tag, typename SUBNET > using mult_prev = add_layer, SUBNET>; template using mult_prev1 = mult_prev; template using mult_prev2 = mult_prev; template using mult_prev3 = mult_prev; template using mult_prev4 = mult_prev; template using mult_prev5 = mult_prev; template using mult_prev6 = mult_prev; template using mult_prev7 = mult_prev; template using mult_prev8 = mult_prev; template using mult_prev9 = mult_prev; template using mult_prev10 = mult_prev; using mult_prev1_ = mult_prev_; using mult_prev2_ = mult_prev_; using mult_prev3_ = mult_prev_; 
using mult_prev4_ = mult_prev_; using mult_prev5_ = mult_prev_; using mult_prev6_ = mult_prev_; using mult_prev7_ = mult_prev_; using mult_prev8_ = mult_prev_; using mult_prev9_ = mult_prev_; using mult_prev10_ = mult_prev_; // ---------------------------------------------------------------------------------------- template < template class tag > class multm_prev_ { public: const static unsigned long id = tag_id::id; multm_prev_() {} template void setup(const SUBNET& /*sub*/) {} template void forward(const SUBNET& sub, resizable_tensor& output) { auto& t1 = sub.get_output(); auto& t2 = layer(sub).get_output(); output.set_size(t1.num_samples(), t1.k(), t1.nr(), t2.nc()); tt::gemm(0, output, 1, t1, false, t2, false, operation_mode::PLANE_WISE); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto& t1 = sub.get_output(); auto& t2 = layer(sub).get_output(); auto& prev = sub.get_gradient_input(); auto& prev_tag = layer(sub).get_gradient_input(); tt::gemm(1, prev, 1, gradient_input, false, t2, true, operation_mode::PLANE_WISE); tt::gemm(1, prev_tag, 1, t1, true, gradient_input, false, operation_mode::PLANE_WISE); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } inline dpoint map_input_to_output(const dpoint& p) const { return p; } inline dpoint map_output_to_input(const dpoint& p) const { return p; } friend void serialize(const multm_prev_& /*item*/, std::ostream& out) { serialize("multm_prev_", out); } friend void deserialize(multm_prev_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "multm_prev_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::multm_prev_."); } friend std::ostream& operator<<(std::ostream& out, const multm_prev_& /*item*/) { out << "multm_prev" << id; return out; } friend void to_xml(const multm_prev_& /*item*/, std::ostream& out) { out << "\n"; } private: 
resizable_tensor params; // unused }; template < template class tag, typename SUBNET > using multm_prev = add_layer, SUBNET>; template using multm_prev1 = multm_prev; template using multm_prev2 = multm_prev; template using multm_prev3 = multm_prev; template using multm_prev4 = multm_prev; template using multm_prev5 = multm_prev; template using multm_prev6 = multm_prev; template using multm_prev7 = multm_prev; template using multm_prev8 = multm_prev; template using multm_prev9 = multm_prev; template using multm_prev10 = multm_prev; using multm_prev1_ = multm_prev_; using multm_prev2_ = multm_prev_; using multm_prev3_ = multm_prev_; using multm_prev4_ = multm_prev_; using multm_prev5_ = multm_prev_; using multm_prev6_ = multm_prev_; using multm_prev7_ = multm_prev_; using multm_prev8_ = multm_prev_; using multm_prev9_ = multm_prev_; using multm_prev10_ = multm_prev_; // ---------------------------------------------------------------------------------------- template < template class tag > class resize_prev_to_tagged_ { public: const static unsigned long id = tag_id::id; resize_prev_to_tagged_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { auto& prev = sub.get_output(); auto& tagged = layer(sub).get_output(); DLIB_CASSERT(prev.num_samples() == tagged.num_samples()); output.set_size(prev.num_samples(), prev.k(), tagged.nr(), tagged.nc()); if (prev.nr() == tagged.nr() && prev.nc() == tagged.nc()) { tt::copy_tensor(false, output, 0, prev, 0, prev.k()); } else { tt::resize_bilinear(output, prev); } } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto& prev = sub.get_gradient_input(); DLIB_CASSERT(prev.k() == gradient_input.k()); DLIB_CASSERT(prev.num_samples() == gradient_input.num_samples()); if (prev.nr() == gradient_input.nr() && prev.nc() == gradient_input.nc()) { tt::copy_tensor(true, prev, 0, gradient_input, 0, prev.k()); } else { 
tt::resize_bilinear_gradient(prev, gradient_input); } } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } friend void serialize(const resize_prev_to_tagged_& /*item*/, std::ostream& out) { serialize("resize_prev_to_tagged_", out); } friend void deserialize(resize_prev_to_tagged_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "resize_prev_to_tagged_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_prev_to_tagged_."); } friend std::ostream& operator<<(std::ostream& out, const resize_prev_to_tagged_& /*item*/) { out << "resize_prev_to_tagged"<\n"; } private: resizable_tensor params; }; template < template class tag, typename SUBNET > using resize_prev_to_tagged = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < template class tag > class scale_ { public: const static unsigned long id = tag_id::id; scale_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { auto&& scales = sub.get_output(); auto&& src = layer(sub).get_output(); DLIB_CASSERT(scales.num_samples() == src.num_samples() && scales.k() == src.k() && scales.nr() == 1 && scales.nc() == 1, "scales.k(): " << scales.k() << "\nsrc.k(): " << src.k() ); output.copy_size(src); tt::scale_channels(false, output, src, scales); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto&& scales = sub.get_output(); auto&& src = layer(sub).get_output(); // The gradient just flows backwards to the two layers that forward() // read from. 
tt::scale_channels(true, layer(sub).get_gradient_input(), gradient_input, scales); if (reshape_src.num_samples() != src.num_samples()) { reshape_scales = alias_tensor(src.num_samples()*src.k()); reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc()); } auto&& scales_grad = sub.get_gradient_input(); auto sgrad = reshape_scales(scales_grad); tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input)); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const scale_& item, std::ostream& out) { serialize("scale_", out); serialize(item.reshape_scales, out); serialize(item.reshape_src, out); } friend void deserialize(scale_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "scale_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_."); deserialize(item.reshape_scales, in); deserialize(item.reshape_src, in); } friend std::ostream& operator<<(std::ostream& out, const scale_& /*item*/) { out << "scale"<\n"; } private: alias_tensor reshape_scales; alias_tensor reshape_src; resizable_tensor params; }; template < template class tag, typename SUBNET > using scale = add_layer, SUBNET>; template using scale1 = scale; template using scale2 = scale; template using scale3 = scale; template using scale4 = scale; template using scale5 = scale; template using scale6 = scale; template using scale7 = scale; template using scale8 = scale; template using scale9 = scale; template using scale10 = scale; using scale1_ = scale_; using scale2_ = scale_; using scale3_ = scale_; using scale4_ = scale_; using scale5_ = scale_; using scale6_ = scale_; using scale7_ = scale_; using scale8_ = scale_; using scale9_ = scale_; using scale10_ = scale_; // ---------------------------------------------------------------------------------------- template < template class tag > class scale_prev_ { public: const 
static unsigned long id = tag_id::id; scale_prev_() { } template void setup (const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { auto&& src = sub.get_output(); auto&& scales = layer(sub).get_output(); DLIB_CASSERT(scales.num_samples() == src.num_samples() && scales.k() == src.k() && scales.nr() == 1 && scales.nc() == 1, "scales.k(): " << scales.k() << "\nsrc.k(): " << src.k() ); output.copy_size(src); tt::scale_channels(false, output, src, scales); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto&& src = sub.get_output(); auto&& scales = layer(sub).get_output(); tt::scale_channels(true, sub.get_gradient_input(), gradient_input, scales); if (reshape_src.num_samples() != src.num_samples()) { reshape_scales = alias_tensor(src.num_samples()*src.k()); reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc()); } auto&& scales_grad = layer(sub).get_gradient_input(); auto sgrad = reshape_scales(scales_grad); tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input)); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } friend void serialize(const scale_prev_& item, std::ostream& out) { serialize("scale_prev_", out); serialize(item.reshape_scales, out); serialize(item.reshape_src, out); } friend void deserialize(scale_prev_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "scale_prev_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_prev_."); deserialize(item.reshape_scales, in); deserialize(item.reshape_src, in); } friend std::ostream& operator<<(std::ostream& out, const scale_prev_& /*item*/) { out << "scale_prev"<\n"; } private: alias_tensor reshape_scales; 
alias_tensor reshape_src; resizable_tensor params; }; template < template class tag, typename SUBNET > using scale_prev = add_layer, SUBNET>; template using scale_prev1 = scale_prev; template using scale_prev2 = scale_prev; template using scale_prev3 = scale_prev; template using scale_prev4 = scale_prev; template using scale_prev5 = scale_prev; template using scale_prev6 = scale_prev; template using scale_prev7 = scale_prev; template using scale_prev8 = scale_prev; template using scale_prev9 = scale_prev; template using scale_prev10 = scale_prev; using scale_prev1_ = scale_prev_; using scale_prev2_ = scale_prev_; using scale_prev3_ = scale_prev_; using scale_prev4_ = scale_prev_; using scale_prev5_ = scale_prev_; using scale_prev6_ = scale_prev_; using scale_prev7_ = scale_prev_; using scale_prev8_ = scale_prev_; using scale_prev9_ = scale_prev_; using scale_prev10_ = scale_prev_; // ---------------------------------------------------------------------------------------- class relu_ { public: relu_() { } void disable() { params.clear(); disabled = true; } bool is_disabled() const { return disabled; } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { if (disabled) return; tt::relu(output, input); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { if (disabled) return; tt::relu_gradient(data_grad, computed_output, gradient_input); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const relu_& item, std::ostream& out) { serialize("relu_2", out); serialize(item.disabled, out); } friend void deserialize(relu_& item, std::istream& in) { std::string version; deserialize(version, in); if (version == "relu_2") { 
deserialize(item.disabled, in); return; } if (version != "relu_" && version != "relu_2") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_."); } friend std::ostream& operator<<(std::ostream& out, const relu_& item) { out << "relu"; if (item.disabled) { out << "\t (disabled)"; } return out; } friend void to_xml(const relu_& item, std::ostream& out) { out << "\n"; } private: resizable_tensor params; bool disabled = false; }; template using relu = add_layer; // ---------------------------------------------------------------------------------------- class prelu_ { public: explicit prelu_( float initial_param_value_ = 0.25 ) : initial_param_value(initial_param_value_) { } float get_initial_param_value ( ) const { return initial_param_value; } template void setup (const SUBNET& /*sub*/) { params.set_size(1); params = initial_param_value; } template void forward( const SUBNET& sub, resizable_tensor& data_output ) { data_output.copy_size(sub.get_output()); tt::prelu(data_output, sub.get_output(), params); } template void backward( const tensor& gradient_input, SUBNET& sub, tensor& params_grad ) { tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input, params, params_grad); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const prelu_& item, std::ostream& out) { serialize("prelu_", out); serialize(item.params, out); serialize(item.initial_param_value, out); } friend void deserialize(prelu_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "prelu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_."); deserialize(item.params, in); deserialize(item.initial_param_value, in); } friend std::ostream& 
operator<<(std::ostream& out, const prelu_& item) { out << "prelu\t (" << "initial_param_value="<\n"; out << mat(item.params); out << "\n"; } private: resizable_tensor params; float initial_param_value; }; template using prelu = add_layer; // ---------------------------------------------------------------------------------------- class leaky_relu_ { public: explicit leaky_relu_( float alpha_ = 0.01f ) : alpha(alpha_) { } float get_alpha( ) const { return alpha; } template void setup(const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { tt::leaky_relu(output, input, alpha); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::leaky_relu_gradient(data_grad, computed_output, gradient_input, alpha); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const leaky_relu_& item, std::ostream& out) { serialize("leaky_relu_", out); serialize(item.alpha, out); } friend void deserialize(leaky_relu_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "leaky_relu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::leaky_relu_."); deserialize(item.alpha, in); } friend std::ostream& operator<<(std::ostream& out, const leaky_relu_& item) { out << "leaky_relu\t(" << "alpha=" << item.alpha << ")"; return out; } friend void to_xml(const leaky_relu_& item, std::ostream& out) { out << "\n"; } private: resizable_tensor params; float alpha; }; template using leaky_relu = add_layer; // ---------------------------------------------------------------------------------------- class sig_ { public: sig_() { } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, 
tensor& output) { tt::sigmoid(output, input); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::sigmoid_gradient(data_grad, computed_output, gradient_input); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const sig_& /*item*/, std::ostream& out) { serialize("sig_", out); } friend void deserialize(sig_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "sig_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_."); } friend std::ostream& operator<<(std::ostream& out, const sig_& /*item*/) { out << "sig"; return out; } friend void to_xml(const sig_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; }; template using sig = add_layer; // ---------------------------------------------------------------------------------------- class mish_ { public: mish_() { } template void setup (const SUBNET& /*sub*/) { } template void forward( const SUBNET& sub, resizable_tensor& data_output ) { data_output.copy_size(sub.get_output()); tt::mish(data_output, sub.get_output()); } template void backward( const tensor& gradient_input, SUBNET& sub, tensor& ) { tt::mish_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const mish_& /*item*/, std::ostream& out) { serialize("mish_", out); } friend void deserialize(mish_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if 
(version != "mish_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::mish_."); } friend std::ostream& operator<<(std::ostream& out, const mish_& /*item*/) { out << "mish"; return out; } friend void to_xml(const mish_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; }; template using mish = add_layer; // ---------------------------------------------------------------------------------------- class htan_ { public: htan_() { } template void setup (const SUBNET& /*sub*/) { } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } void forward_inplace(const tensor& input, tensor& output) { tt::tanh(output, input); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::tanh_gradient(data_grad, computed_output, gradient_input); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const htan_& /*item*/, std::ostream& out) { serialize("htan_", out); } friend void deserialize(htan_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "htan_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_."); } friend std::ostream& operator<<(std::ostream& out, const htan_& /*item*/) { out << "htan"; return out; } friend void to_xml(const htan_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; }; template using htan = add_layer; // ---------------------------------------------------------------------------------------- class clipped_relu_ { public: clipped_relu_( const float ceiling_ = 6.0f ) : ceiling(ceiling_) { } float get_ceiling( ) const { return ceiling; } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { 
tt::clipped_relu(output, input, ceiling); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::clipped_relu_gradient(data_grad, computed_output, gradient_input, ceiling); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const clipped_relu_& item, std::ostream& out) { serialize("clipped_relu_", out); serialize(item.ceiling, out); } friend void deserialize(clipped_relu_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "clipped_relu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::clipped_relu_."); deserialize(item.ceiling, in); } friend std::ostream& operator<<(std::ostream& out, const clipped_relu_& item) { out << "clipped_relu\t(" << "ceiling=" << item.ceiling << ")"; return out; } friend void to_xml(const clipped_relu_& item, std::ostream& out) { out << "\n"; } private: resizable_tensor params; float ceiling; }; template using clipped_relu = add_layer; // ---------------------------------------------------------------------------------------- class elu_ { public: elu_( const float alpha_ = 1.0f ) : alpha(alpha_) { } float get_alpha( ) const { return alpha; } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { tt::elu(output, input, alpha); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::elu_gradient(data_grad, computed_output, gradient_input, alpha); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& 
get_layer_params() { return params; } friend void serialize(const elu_& item, std::ostream& out) { serialize("elu_", out); serialize(item.alpha, out); } friend void deserialize(elu_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "elu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::elu_."); deserialize(item.alpha, in); } friend std::ostream& operator<<(std::ostream& out, const elu_& item) { out << "elu\t (" << "alpha=" << item.alpha << ")"; return out; } friend void to_xml(const elu_& item, std::ostream& out) { out << "\n"; } private: resizable_tensor params; float alpha; }; template using elu = add_layer; // ---------------------------------------------------------------------------------------- class gelu_ { public: gelu_() { } template void setup (const SUBNET& /*sub*/) { } template void forward( const SUBNET& sub, resizable_tensor& data_output ) { data_output.copy_size(sub.get_output()); tt::gelu(data_output, sub.get_output()); } template void backward( const tensor& gradient_input, SUBNET& sub, tensor& ) { tt::gelu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const gelu_& /*item*/, std::ostream& out) { serialize("gelu_", out); } friend void deserialize(gelu_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "gelu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::gelu_."); } friend std::ostream& operator<<(std::ostream& out, const gelu_& /*item*/) { out << "gelu"; return out; } friend void to_xml(const gelu_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; }; 
template using gelu = add_layer; // ---------------------------------------------------------------------------------------- class smelu_ { public: explicit smelu_( float beta_ = 1 ) : beta(beta_) { } float get_beta( ) const { return beta; } template void setup(const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { tt::smelu(output, input, beta); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::smelu_gradient(data_grad, computed_output, gradient_input, beta); } inline dpoint map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const smelu_& item, std::ostream& out) { serialize("smelu_", out); serialize(item.beta, out); } friend void deserialize(smelu_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "smelu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::smelu_."); deserialize(item.beta, in); } friend std::ostream& operator<<(std::ostream& out, const smelu_& item) { out << "smelu\t (" << "beta=" << item.beta << ")"; return out; } friend void to_xml(const smelu_& item, std::ostream& out) { out << "\n"; } private: resizable_tensor params; float beta; }; template using smelu = add_layer; // ---------------------------------------------------------------------------------------- class silu_ { public: silu_( ) { } template void setup(const SUBNET& /*sub*/) { } template void forward( const SUBNET& sub, resizable_tensor& data_ouput) { data_ouput.copy_size(sub.get_output()); tt::silu(data_ouput, sub.get_output()); } template void backward( const tensor& gradient_input, SUBNET& sub, tensor& ) { tt::silu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input); } inline dpoint 
map_input_to_output (const dpoint& p) const { return p; } inline dpoint map_output_to_input (const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const silu_& /*item*/, std::ostream& out) { serialize("silu_", out); } friend void deserialize(silu_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "silu_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::silu_."); } friend std::ostream& operator<<(std::ostream& out, const silu_& /*item*/) { out << "silu"; return out; } friend void to_xml(const silu_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; }; template using silu = add_layer; // ---------------------------------------------------------------------------------------- template class softmax_ { public: softmax_() {} template void setup(const SUBNET& /*sub*/) {} void forward_inplace(const tensor& input, tensor& output) { tt::softmax(output, input, s_mode_); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& /*params_grad*/ ) { tt::softmax_gradient(data_grad, computed_output, gradient_input, s_mode_); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const softmax_& /*item*/, std::ostream& out) { serialize("softmax_", out); } friend void deserialize(softmax_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "softmax_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::softmax_."); } friend std::ostream& operator<<(std::ostream& out, const softmax_& /*item*/) { out << "softmax (mode=" << (s_mode_ == operation_mode::CHANNEL_WISE ? 
"channel_wise" : "plane_wise") << ")"; return out; } friend void to_xml(const softmax_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; // unused }; template using softmax = add_layer, SUBNET>; template using softmaxm = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class softmax_all_ { public: softmax_all_() { } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& input, tensor& output) { tt::softmax_all(output, input); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& ) { tt::softmax_all_gradient(data_grad, computed_output, gradient_input); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const softmax_all_& /*item*/, std::ostream& out) { serialize("softmax_all_", out); } friend void deserialize(softmax_all_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "softmax_all_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_all_."); } friend std::ostream& operator<<(std::ostream& out, const softmax_all_& /*item*/) { out << "softmax_all"; return out; } friend void to_xml(const softmax_all_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; }; template using softmax_all = add_layer; // ---------------------------------------------------------------------------------------- namespace impl { template class TAG_TYPE, template class... TAG_TYPES> struct concat_helper_impl{ constexpr static size_t tag_count() {return 1 + concat_helper_impl::tag_count();} static void list_tags(std::ostream& out) { out << tag_id::id << (tag_count() > 1 ? 
"," : ""); concat_helper_impl::list_tags(out); } template static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k) { auto& t = layer(sub).get_output(); concat_helper_impl::resize_out(out, sub, sum_k + t.k()); } template static void concat(tensor& out, const SUBNET& sub, size_t k_offset) { auto& t = layer(sub).get_output(); tt::copy_tensor(false, out, k_offset, t, 0, t.k()); k_offset += t.k(); concat_helper_impl::concat(out, sub, k_offset); } template static void split(const tensor& input, SUBNET& sub, size_t k_offset) { auto& t = layer(sub).get_gradient_input(); tt::copy_tensor(true, t, 0, input, k_offset, t.k()); k_offset += t.k(); concat_helper_impl::split(input, sub, k_offset); } }; template class TAG_TYPE> struct concat_helper_impl{ constexpr static size_t tag_count() {return 1;} static void list_tags(std::ostream& out) { out << tag_id::id; } template static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k) { auto& t = layer(sub).get_output(); out.set_size(t.num_samples(), t.k() + sum_k, t.nr(), t.nc()); } template static void concat(tensor& out, const SUBNET& sub, size_t k_offset) { auto& t = layer(sub).get_output(); tt::copy_tensor(false, out, k_offset, t, 0, t.k()); } template static void split(const tensor& input, SUBNET& sub, size_t k_offset) { auto& t = layer(sub).get_gradient_input(); tt::copy_tensor(true, t, 0, input, k_offset, t.k()); } }; } // concat layer template< template class... 
TAG_TYPES > class concat_ { static void list_tags(std::ostream& out) { impl::concat_helper_impl::list_tags(out);}; public: constexpr static size_t tag_count() {return impl::concat_helper_impl::tag_count();}; template void setup (const SUBNET&) { // do nothing } template void forward(const SUBNET& sub, resizable_tensor& output) { // the total depth of result is the sum of depths from all tags impl::concat_helper_impl::resize_out(output, sub, 0); // copy output from each tag into different part result impl::concat_helper_impl::concat(output, sub, 0); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor&) { // Gradient is split into parts for each tag layer impl::concat_helper_impl::split(gradient_input, sub, 0); } dpoint map_input_to_output(dpoint p) const { return p; } dpoint map_output_to_input(dpoint p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const concat_& /*item*/, std::ostream& out) { serialize("concat_", out); size_t count = tag_count(); serialize(count, out); } friend void deserialize(concat_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "concat_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_."); size_t count_tags; deserialize(count_tags, in); if (count_tags != tag_count()) throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " + std::to_string(tag_count()) + " found while deserializing dlib::concat_."); } friend std::ostream& operator<<(std::ostream& out, const concat_& /*item*/) { out << "concat\t ("; list_tags(out); out << ")"; return out; } friend void to_xml(const concat_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; // unused }; // concat layer definitions template class TAG1, template class TAG2, typename SUBNET> using concat2 = add_layer, SUBNET>; 
template class TAG1, template class TAG2, template class TAG3, typename SUBNET> using concat3 = add_layer, SUBNET>; template class TAG1, template class TAG2, template class TAG3, template class TAG4, typename SUBNET> using concat4 = add_layer, SUBNET>; template class TAG1, template class TAG2, template class TAG3, template class TAG4, template class TAG5, typename SUBNET> using concat5 = add_layer, SUBNET>; // inception layer will use tags internally. If user will use tags too, some conflicts // possible to exclude them, here are new tags specially for inceptions template using itag0 = add_tag_layer< 1000 + 0, SUBNET>; template using itag1 = add_tag_layer< 1000 + 1, SUBNET>; template using itag2 = add_tag_layer< 1000 + 2, SUBNET>; template using itag3 = add_tag_layer< 1000 + 3, SUBNET>; template using itag4 = add_tag_layer< 1000 + 4, SUBNET>; template using itag5 = add_tag_layer< 1000 + 5, SUBNET>; // skip to inception input template using iskip = add_skip_layer< itag0, SUBNET>; // here are some templates to be used for creating inception layer groups template class B1, templateclass B2, typename SUBNET> using inception2 = concat2>>>>>>; template class B1, templateclass B2, templateclass B3, typename SUBNET> using inception3 = concat3>>>>>>>>>; template class B1, templateclass B2, templateclass B3, templateclass B4, typename SUBNET> using inception4 = concat4>>>>>>>>>>>>; template class B1, templateclass B2, templateclass B3, templateclass B4, templateclass B5, typename SUBNET> using inception5 = concat5>>>>>>>>>>>>>>>; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- const double DEFAULT_L2_NORM_EPS = 1e-5; class l2normalize_ { public: explicit l2normalize_( double eps_ = DEFAULT_L2_NORM_EPS ) : eps(eps_) { } double get_eps() const { return eps; } template void setup (const SUBNET& /*sub*/) { } void forward_inplace(const tensor& 
input, tensor& output) { tt::inverse_norms(norm, input, eps); tt::scale_rows(output, input, norm); } void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& /*params_grad*/ ) { if (is_same_object(gradient_input, data_grad)) { tt::dot_prods(temp, gradient_input, computed_output); tt::scale_rows2(0, data_grad, gradient_input, computed_output, temp, norm); } else { tt::dot_prods(temp, gradient_input, computed_output); tt::scale_rows2(1, data_grad, gradient_input, computed_output, temp, norm); } } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const l2normalize_& item, std::ostream& out) { serialize("l2normalize_", out); serialize(item.eps, out); } friend void deserialize(l2normalize_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "l2normalize_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::l2normalize_."); deserialize(item.eps, in); } friend std::ostream& operator<<(std::ostream& out, const l2normalize_& item) { out << "l2normalize"; out << " eps="<\n"; } private: double eps; resizable_tensor params; // unused // Here only to avoid reallocation and as a cache between forward/backward // functions. 
resizable_tensor norm; resizable_tensor temp; }; template using l2normalize = add_layer; // ---------------------------------------------------------------------------------------- template < long _offset, long _k, long _nr, long _nc > class extract_ { static_assert(_offset >= 0, "The offset must be >= 0."); static_assert(_k > 0, "The number of channels must be > 0."); static_assert(_nr > 0, "The number of rows must be > 0."); static_assert(_nc > 0, "The number of columns must be > 0."); public: extract_( ) { } template void setup (const SUBNET& sub) { DLIB_CASSERT((long)sub.get_output().size() >= sub.get_output().num_samples()*(_offset+_k*_nr*_nc), "The tensor we are trying to extract from the input tensor is too big to fit into the input tensor."); aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc); ain = alias_tensor(sub.get_output().num_samples(), sub.get_output().size()/sub.get_output().num_samples()); } template void forward(const SUBNET& sub, resizable_tensor& output) { if (aout.num_samples() != sub.get_output().num_samples()) { aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc); ain = alias_tensor(sub.get_output().num_samples(), sub.get_output().size()/sub.get_output().num_samples()); } output.set_size(sub.get_output().num_samples(), _k, _nr, _nc); auto out = aout(output,0); auto in = ain(sub.get_output(),0); tt::copy_tensor(false, out, 0, in, _offset, _k*_nr*_nc); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto out = ain(sub.get_gradient_input(),0); auto in = aout(gradient_input,0); tt::copy_tensor(true, out, _offset, in, 0, _k*_nr*_nc); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const extract_& /*item*/, std::ostream& out) { serialize("extract_", out); serialize(_offset, out); serialize(_k, out); serialize(_nr, out); serialize(_nc, out); } friend void deserialize(extract_& /*item*/, 
std::istream& in) { std::string version; deserialize(version, in); if (version != "extract_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::extract_."); long offset; long k; long nr; long nc; deserialize(offset, in); deserialize(k, in); deserialize(nr, in); deserialize(nc, in); if (offset != _offset) throw serialization_error("Wrong offset found while deserializing dlib::extract_"); if (k != _k) throw serialization_error("Wrong k found while deserializing dlib::extract_"); if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::extract_"); if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::extract_"); } friend std::ostream& operator<<(std::ostream& out, const extract_& /*item*/) { out << "extract\t (" << "offset="<<_offset << ", k="<<_k << ", nr="<<_nr << ", nc="<<_nc << ")"; return out; } friend void to_xml(const extract_& /*item*/, std::ostream& out) { out << "\n"; } private: alias_tensor aout, ain; resizable_tensor params; // unused }; template < long offset, long k, long nr, long nc, typename SUBNET > using extract = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _offset_k, long _offset_nr, long _offset_nc, long _k, long _nr, long _nc > class slice_ { static_assert(_offset_k >= 0, "The channel offset must be >= 0."); static_assert(_offset_nr >= 0, "The row offset must be >= 0."); static_assert(_offset_nc >= 0, "The column offset must be >= 0."); static_assert(_k > 0, "The number of channels must be > 0."); static_assert(_nr > 0, "The number of rows must be > 0."); static_assert(_nc > 0, "The number of columns must be > 0."); public: slice_( ) { } template void setup (const SUBNET& sub) { DLIB_CASSERT((long)sub.get_output().size() >= sub.get_output().num_samples()*(_offset_k+_offset_nr+_offset_nc+_k*_nr*_nc), "The tensor we are trying to slice from the input tensor is too big to 
fit into the input tensor."); } template void forward(const SUBNET& sub, resizable_tensor& output) { output.set_size(sub.get_output().num_samples(), _k, _nr, _nc); tt::copy_tensor(false, output, 0, 0, 0, sub.get_output(), _offset_k, _offset_nr, _offset_nc, _k, _nr, _nc); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { tt::copy_tensor(true, sub.get_gradient_input(), _offset_k, _offset_nr, _offset_nc, gradient_input, 0, 0, 0, _k, _nr, _nc); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const slice_& /*item*/, std::ostream& out) { serialize("slice_", out); serialize(_offset_k, out); serialize(_offset_nr, out); serialize(_offset_nc, out); serialize(_k, out); serialize(_nr, out); serialize(_nc, out); } friend void deserialize(slice_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "slice_") throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::slice_."); long offset_k; long offset_nr; long offset_nc; long k; long nr; long nc; deserialize(offset_k, in); deserialize(offset_nr, in); deserialize(offset_nc, in); deserialize(k, in); deserialize(nr, in); deserialize(nc, in); if (offset_k != _offset_k) throw serialization_error("Wrong offset_k found while deserializing dlib::slice_"); if (offset_nr != _offset_nr) throw serialization_error("Wrong offset_nr found while deserializing dlib::slice_"); if (offset_nc != _offset_nc) throw serialization_error("Wrong offset_nc found while deserializing dlib::slice_"); if (k != _k) throw serialization_error("Wrong k found while deserializing dlib::slice_"); if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::slice_"); if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::slice_"); } friend std::ostream& operator<<(std::ostream& out, const slice_& /*item*/) { out << 
"slice\t (" << "offset_k="<<_offset_k << "offset_nr="<<_offset_nr << "offset_nc="<<_offset_nc << ", k="<<_k << ", nr="<<_nr << ", nc="<<_nc << ")"; return out; } friend void to_xml(const slice_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; // unused }; template < long offset_k, long offset_nr, long offset_nc, long k, long nr, long nc, typename SUBNET > using slice = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template class reorg_ { static_assert(row_stride >= 1, "The row_stride must be >= 1"); static_assert(row_stride >= 1, "The col_stride must be >= 1"); public: reorg_( ) { } template void setup (const SUBNET& sub) { DLIB_CASSERT(sub.get_output().nr() % row_stride == 0); DLIB_CASSERT(sub.get_output().nc() % col_stride == 0); } template void forward(const SUBNET& sub, resizable_tensor& output) { output.set_size( sub.get_output().num_samples(), sub.get_output().k() * col_stride * row_stride, sub.get_output().nr() / row_stride, sub.get_output().nc() / col_stride ); tt::reorg(false, output, row_stride, col_stride, sub.get_output()); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { tt::reorg_gradient(true, sub.get_gradient_input(), row_stride, col_stride, gradient_input); } inline dpoint map_input_to_output (dpoint p) const { p.x() = p.x() / col_stride; p.y() = p.y() / row_stride; return p; } inline dpoint map_output_to_input (dpoint p) const { p.x() = p.x() * col_stride; p.y() = p.y() * row_stride; return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const reorg_& /*item*/, std::ostream& out) { serialize("reorg_", out); serialize(row_stride, out); serialize(col_stride, out); } friend void deserialize(reorg_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "reorg_") throw 
serialization_error("Unexpected version '"+version+"' found while deserializing dlib::reorg_."); long long rs; long long cs; deserialize(rs, in); deserialize(cs, in); if (rs != row_stride) throw serialization_error("Wrong row_stride found while deserializing dlib::reorg_"); if (cs != col_stride) throw serialization_error("Wrong col_stride found while deserializing dlib::reorg_"); } friend std::ostream& operator<<(std::ostream& out, const reorg_& /*item*/) { out << "reorg\t (" << "row_stride=" << row_stride << ", col_stride=" << col_stride << ")"; return out; } friend void to_xml(const reorg_ /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; // unused }; template using reorg = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class transpose_ { public: transpose_() {} template void setup(const SUBNET& /*sub*/) {} template void forward(const SUBNET& sub, resizable_tensor& output) { auto& prev = sub.get_output(); output.set_size(prev.num_samples(), prev.k(), prev.nc(), prev.nr()); tt::transpose(false, output, prev); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto& prev = sub.get_gradient_input(); tt::transpose(true, prev, gradient_input); } inline dpoint map_input_to_output(dpoint p) const { dpoint temp_p; temp_p.x() = p.y(); temp_p.y() = p.x(); return temp_p; } inline dpoint map_output_to_input(dpoint p) const { dpoint temp_p; temp_p.x() = p.y(); temp_p.y() = p.x(); return temp_p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const transpose_& /*item*/, std::ostream& out) { serialize("transpose_", out); } friend void deserialize(transpose_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "transpose_") throw serialization_error("Unexpected version '" + version + "' found while deserializing 
dlib::transpose_."); } friend std::ostream& operator<<(std::ostream& out, const transpose_& /*item*/) { out << "transpose"; return out; } friend void to_xml(const transpose_& /*item*/, std::ostream& out) { out << "\n"; } private: dlib::resizable_tensor params; // unused }; template using transpose = add_layer; // ---------------------------------------------------------------------------------------- class positional_encodings_ { public: positional_encodings_(unsigned long sequence_dim_ = 1, unsigned long embedding_dim_ = 1) : sequence_dim(sequence_dim_), embedding_dim(embedding_dim_) { } positional_encodings_(const positional_encodings_& item) : pe(item.pe), sequence_dim(item.sequence_dim), embedding_dim(item.embedding_dim) { } positional_encodings_& operator= (const positional_encodings_& item) { if (this == &item) return *this; pe = item.pe; sequence_dim = item.sequence_dim; embedding_dim = item.embedding_dim; return *this; } template void setup(const SUBNET& sub) { auto& prev = sub.get_output(); sequence_dim = prev.nr(); embedding_dim = prev.nc(); const unsigned long ns = prev.num_samples(); const unsigned long nk = prev.k(); const float n = 10000.0f; pe.set_size(ns, nk, sequence_dim, embedding_dim); for (unsigned long s = 0; s < ns; ++s) { for (unsigned long k = 0; k < nk; ++k) { for (unsigned long r = 0; r < sequence_dim; ++r) { for (unsigned long c = 0; c < embedding_dim; ++c) { float theta = static_cast(r) / std::pow(n, static_cast(c) / embedding_dim); if (c % 2 == 0) pe.host()[tensor_index(pe, s, k, r, c)] = std::sin(theta); else pe.host()[tensor_index(pe, s, k, r, c)] = std::cos(theta); } } } } } template void forward(const SUBNET& sub, resizable_tensor& output) { const auto& prev_output = sub.get_output(); if (!have_same_dimensions(pe, prev_output)) setup(sub); output.set_size(prev_output.num_samples(), prev_output.k(), sequence_dim, embedding_dim); tt::add(output, prev_output, pe); } template void backward(const tensor& gradient_input, SUBNET& sub, 
tensor& /*params_grad*/) { auto& prev_grad = sub.get_gradient_input(); tt::add(prev_grad, prev_grad, gradient_input); } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } const tensor& get_positional_encodings() const { return pe; } tensor& get_positional_encodings() { return pe; } friend void serialize(const positional_encodings_& /*item*/, std::ostream& out) { serialize("positional_encodings_", out); } friend void deserialize(positional_encodings_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "positional_encodings_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::positional_encodings_."); } friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& /*item*/) { out << "positional_encodings"; return out; } friend void to_xml(const positional_encodings_& /*item*/, std::ostream& out) { out << "\n"; } private: resizable_tensor params; // unused resizable_tensor pe; unsigned long sequence_dim, embedding_dim; }; template using positional_encodings = add_layer; // ---------------------------------------------------------------------------------------- template< unsigned long num_embeddings_, unsigned long embedding_dim_ > class embeddings_ { static_assert(num_embeddings_ > 0, "The size of the embedding dictionary must be > 0"); static_assert(embedding_dim_ > 0, "The size of each embedding vector must be > 0"); public: embeddings_() : num_embeddings(num_embeddings_), embedding_dim(embedding_dim_), learning_rate_multiplier(1.0f), scale_by_freq(true) { } double get_learning_rate_multiplier() const { return learning_rate_multiplier; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_scale_by_freq(bool val) { scale_by_freq = val; } bool get_scale_by_freq() const { return scale_by_freq; } unsigned long get_num_embeddings() const { return num_embeddings; } void 
set_num_embeddings(unsigned long num) { DLIB_CASSERT(num > 0); if (num != num_embeddings) { DLIB_CASSERT(get_embeddings().size() == 0, "It is not possible to change the size of the embedding dictionary if the parameter has already been assigned."); } } unsigned long get_embedding_dim() const { return embedding_dim; } void set_embedding_dim(unsigned long dim) { DLIB_CASSERT(dim > 0); if (dim != embedding_dim) { DLIB_CASSERT(get_embeddings().size() == 0, "It is not possible to change the size of the embedding dictionary if the parameter has already been assigned."); } } template void setup(const SUBNET& /*sub*/) { embs.set_size(num_embeddings, embedding_dim); tt::tensor_rand rnd(std::rand()); rnd.fill_gaussian(embs); } template void forward(const SUBNET& sub, resizable_tensor& output) { const auto& prev = sub.get_output(); output.set_size(prev.num_samples(), prev.k(), prev.nr(), embedding_dim); tt::embeddings(output, prev, embs); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { // Because this class is expected to be directly after an layer, // it's not necessary to propagate the gradient. // Additionally, this layer is treated as constant during backpropagation, // so it technically doesn't contribute to the gradient computation. 
if (learning_rate_multiplier != 0) { auto& prev_src = sub.get_output(); calc_token_freqs(prev_src, gradient_input); tt::embeddings_gradient(prev_src, gradient_input, embs, freqs, learning_rate_multiplier, scale_by_freq); } } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } const tensor& get_embeddings() const { return embs; } tensor& get_embeddings() { return embs; } friend void serialize(const embeddings_& item, std::ostream& out) { serialize("embeddings_", out); serialize(item.embs, out); serialize(item.num_embeddings, out); serialize(item.embedding_dim, out); serialize(item.learning_rate_multiplier, out); serialize(item.scale_by_freq, out); } friend void deserialize(embeddings_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "embeddings_") throw serialization_error("Unexpected version found while deserializing dlib::embeddings_."); deserialize(item.embs, in); deserialize(item.num_embeddings, in); deserialize(item.embedding_dim, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.scale_by_freq, in); } friend std::ostream& operator<<(std::ostream& out, const embeddings_& item) { out << "embeddings (num_embeddings=" << item.num_embeddings << ", embedding_dim=" << item.embedding_dim << ") learning_rate_mult=" << item.learning_rate_multiplier; return out; } friend void to_xml(const embeddings_& item, std::ostream& out) { out << "\n"; out << mat(item.embs); out << "\n"; } private: void calc_token_freqs(const tensor& prev, const tensor& input) { if (freqs.size() == 0) freqs.set_size(num_embeddings, 1, 1, 1); freqs = 0; const float* prev_data = prev.host(); float* freqs_data = freqs.host(); for (long s = 0; s < input.num_samples(); ++s) { for (long k = 0; k < input.k(); ++k) { for (long r = 0; r < input.nr(); ++r) { const unsigned long token_idx = static_cast(prev_data[tensor_index(prev, s, k, r, 0)]); if (token_idx < num_embeddings) 
freqs_data[tensor_index(freqs, token_idx, 0, 0, 0)]++; } } } } resizable_tensor params; // unused resizable_tensor embs, freqs; unsigned long num_embeddings, embedding_dim; double learning_rate_multiplier; bool scale_by_freq; }; template < unsigned long nb_embeddings, unsigned long embedding_length, typename SUBNET > using embeddings = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- struct neg_infinity_tag {}; struct zero_tag {}; template struct is_special_value : std::false_type {}; template<> struct is_special_value : std::true_type {}; template<> struct is_special_value : std::true_type {}; template class tril_ { public: tril_(): diag(diag_), diag_value(compute_diag_value()) {} template void setup(const SUBNET& /*sub*/) { } template void forward(const SUBNET& sub, resizable_tensor& output) { auto& prev = sub.get_output(); output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc()); check_mask(prev); tt::multiply(false, output, prev, binary_mask); if (diag_value != 0.0f) tt::add(1, output, 1, output_mask); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { auto& prev_grad = sub.get_gradient_input(); tt::multiply(true, prev_grad, gradient_input, binary_mask); } inline dpoint map_input_to_output(const dpoint& p) const { return p; } inline dpoint map_output_to_input(const dpoint& p) const { return p; } const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } friend void serialize(const tril_& item, std::ostream& out) { serialize("tril_", out); serialize(item.diag, out); serialize(item.diag_value, out); } friend void deserialize(tril_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "tril_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_."); deserialize(item.diag, in); deserialize(item.diag_value, in); } friend 
std::ostream& operator<<(std::ostream& out, const tril_& item) { out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")"; return out; } friend void to_xml(const tril_& item, std::ostream& out) { out << "\n"; } private: float compute_diag_value() const { if (std::is_same::value) return -std::numeric_limits::infinity(); else if (std::is_same::value) return 0.0f; else return static_cast(num_) / static_cast(den_); } void check_mask(const tensor& t) { if (!have_same_dimensions(binary_mask, t)) { binary_mask.copy_size(t); binary_mask = 1; if (diag_value != 0.0f) { output_mask.copy_size(t); output_mask = 0; } for (long s = 0; s < output_mask.num_samples(); ++s) { for (long k = 0; k < output_mask.k(); ++k) { for (long r = 0; r < output_mask.nr(); ++r) { for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c) { if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value; binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0; } } } } } } template struct always_false : std::false_type {}; resizable_tensor params; // unused resizable_tensor binary_mask, output_mask; long diag; float diag_value; }; template using tril = add_layer, SUBNET>; template using tril_mask = add_layer, SUBNET>; template using tril_diag = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template class adaptive_computation_time_ { public: explicit adaptive_computation_time_() : max_steps_(max_steps), halt_threshold_(0.99f), // theta in Graves' notation ponder_penalty_(0.01f), // lambda (ponder cost weight) enable_depth_scaling_(false), batch_size_(0), seq_len_(0), d_model_(0), num_channels_(0), feature_dim_(0), ponder_cost_(0), avg_steps_(0) { } adaptive_computation_time_(const adaptive_computation_time_& item) : max_steps_(item.max_steps_), halt_threshold_(item.halt_threshold_), ponder_penalty_(item.ponder_penalty_), 
enable_depth_scaling_(item.enable_depth_scaling_), batch_size_(item.batch_size_), seq_len_(item.seq_len_), d_model_(item.d_model_), num_channels_(item.num_channels_), feature_dim_(item.feature_dim_), ponder_cost_(item.ponder_cost_), avg_steps_(item.avg_steps_), params(item.params), halting_probs_(item.halting_probs_), cumulative_halting_(item.cumulative_halting_), remainders_(item.remainders_), n_steps_(item.n_steps_), logits_(item.logits_), grad_logits_(item.grad_logits_), input_cache_(item.input_cache_), true_effective_weights_(item.true_effective_weights_) { } adaptive_computation_time_& operator=(const adaptive_computation_time_& item) { if (this == &item) return *this; max_steps_ = item.max_steps_; halt_threshold_ = item.halt_threshold_; ponder_penalty_ = item.ponder_penalty_; enable_depth_scaling_ = item.enable_depth_scaling_; batch_size_ = item.batch_size_; seq_len_ = item.seq_len_; d_model_ = item.d_model_; num_channels_ = item.num_channels_; feature_dim_ = item.feature_dim_; ponder_cost_ = item.ponder_cost_; avg_steps_ = item.avg_steps_; params = item.params; halting_probs_ = item.halting_probs_; cumulative_halting_ = item.cumulative_halting_; remainders_ = item.remainders_; n_steps_ = item.n_steps_; logits_ = item.logits_; grad_logits_ = item.grad_logits_; input_cache_ = item.input_cache_; true_effective_weights_ = item.true_effective_weights_; return *this; } template void setup(const SUBNET& sub) { const auto& input = sub.get_output(); // Store expected dimensions for parameter initialization batch_size_ = input.num_samples(); seq_len_ = input.nr(); d_model_ = input.nc(); num_channels_ = input.k(); feature_dim_ = d_model_ * num_channels_; // Initialize halting parameters params.set_size(1, 1, feature_dim_ + 1, 1); // He initialization for stability dlib::rand rnd(std::rand()); const float scale = std::sqrt(2.0f / feature_dim_); float* p = params.host(); // Initialize weight matrix W_halt for (long i = 0; i < feature_dim_; ++i) p[i] = 
rnd.get_random_gaussian() * scale; // Initialize bias b_halt (typically zero) p[feature_dim_] = 0.0f; // Pre-allocate workspace for maximum expected size allocate_workspace(); } template void forward(const SUBNET& sub, resizable_tensor& output) { const tensor& input = sub.get_output(); output.copy_size(input); // Ensure workspace is allocated for current batch dimensions const long curr_batch = input.num_samples(); if (curr_batch != batch_size_) { batch_size_ = curr_batch; allocate_workspace(); } // Initialize output for weighted accumulation output = 0; // Initialize ACT state vectors const long total_positions = batch_size_ * seq_len_; float* cum_halt_ptr = cumulative_halting_.host(); float* remainders_ptr = remainders_.host(); float* n_steps_ptr = n_steps_.host(); for (long i = 0; i < total_positions; ++i) { cum_halt_ptr[i] = 0.0f; // h_t^n: cumulative halting probability remainders_ptr[i] = 1.0f; // ρ_t: remaining probability mass n_steps_ptr[i] = 0.0f; // N(t): number of computation steps } // Cache input for backward pass input_cache_.copy_size(input); tt::copy_tensor(false, input_cache_, 0, input, 0, input.k()); // Initialize effective weights tracker for gradient computation true_effective_weights_.set_size(total_positions, 1, 1, 1); true_effective_weights_ = 0; // Main ACT computation loop for (long step = 0; step < max_steps_; ++step) { // Compute halting probabilities: p_t^n = sigmoid(W_halt^T * s_t^n + b_halt) tt::compute_act_halt_probabilities( halting_probs_, logits_, input, params, batch_size_, seq_len_, feature_dim_); // Update ACT state and accumulate weighted outputs tt::update_act_state( output, input, halting_probs_, cumulative_halting_, remainders_, n_steps_, true_effective_weights_, batch_size_, seq_len_, d_model_, num_channels_, halt_threshold_, step ); // Early termination optimization if (all_positions_halted(cumulative_halting_)) break; } // Finalize with remainder contributions tt::finalize_act_output( output, input, remainders_, 
true_effective_weights_, batch_size_, seq_len_, d_model_, num_channels_); // Compute statistics for monitoring and regularization compute_ponder_stats(); } template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) { tensor& input_grad = sub.get_gradient_input(); const float* grad_in = gradient_input.host(); const float* eff_weights = true_effective_weights_.host(); float* grad_out = input_grad.host(); for (long n = 0; n < batch_size_; ++n) { for (long s = 0; s < seq_len_; ++s) { const long pos = n * seq_len_ + s; const float weight = eff_weights[pos]; for (long c = 0; c < num_channels_; ++c) { for (long d = 0; d < d_model_; ++d) { const long idx = ((n * num_channels_ + c) * seq_len_ + s) * d_model_ + d; grad_out[idx] += weight * grad_in[idx]; } } } } // Compute parameter gradients from ponder cost regularization params_grad = 0; // Optional: Apply depth-dependent gradient scaling if (enable_depth_scaling_) { tt::apply_act_depth_scaling( input_grad, n_steps_, batch_size_, seq_len_, d_model_, num_channels_, static_cast(max_steps_), 0.1f ); } } // Accessor methods const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } long get_max_steps() const { return max_steps_; } float get_halt_threshold() const { return halt_threshold_; } float get_ponder_penalty() const { return ponder_penalty_; } void set_halt_threshold(float threshold) { if (threshold > 0 && threshold <= 1.0f) halt_threshold_ = threshold; } void set_ponder_penalty(float penalty) { if (penalty >= 0) ponder_penalty_ = penalty; } // Statistics for monitoring and regularization float get_ponder_cost() const { return ponder_cost_; } // R(x) float get_average_steps() const { return avg_steps_; } // Average N(t) // Depth scaling control void enable_depth_scaling() { enable_depth_scaling_ = true; } void disable_depth_scaling() { enable_depth_scaling_ = false; } bool depth_scaling_enabled() const { return enable_depth_scaling_; } inline 
dpoint map_input_to_output(const dpoint& p) const { return p; } inline dpoint map_output_to_input(const dpoint& p) const { return p; } // Serialization methods friend void serialize(const adaptive_computation_time_& item, std::ostream& out) { dlib::serialize("act_", out); dlib::serialize(item.max_steps_, out); dlib::serialize(item.halt_threshold_, out); dlib::serialize(item.ponder_penalty_, out); dlib::serialize(item.enable_depth_scaling_, out); dlib::serialize(item.batch_size_, out); dlib::serialize(item.seq_len_, out); dlib::serialize(item.d_model_, out); dlib::serialize(item.num_channels_, out); dlib::serialize(item.feature_dim_, out); dlib::serialize(item.params, out); } friend void deserialize(adaptive_computation_time_& item, std::istream& in) { std::string version; dlib::deserialize(version, in); if (version != "act_") throw serialization_error("Unexpected version: " + version); dlib::deserialize(item.max_steps_, in); dlib::deserialize(item.halt_threshold_, in); dlib::deserialize(item.ponder_penalty_, in); dlib::deserialize(item.enable_depth_scaling_, in); dlib::deserialize(item.batch_size_, in); dlib::deserialize(item.seq_len_, in); dlib::deserialize(item.d_model_, in); dlib::deserialize(item.num_channels_, in); dlib::deserialize(item.feature_dim_, in); dlib::deserialize(item.params, in); item.allocate_workspace(); } friend std::ostream& operator<<(std::ostream& out, const adaptive_computation_time_& item) { out << "act (steps=" << item.max_steps_ << ", dim=" << item.feature_dim_ << ", threshold=" << item.halt_threshold_ << ", penalty=" << item.ponder_penalty_ << ")"; return out; } friend void to_xml(const adaptive_computation_time_& item, std::ostream& out) { out << "\n"; out << mat(item.params); out << "\n"; } private: void allocate_workspace() { const long total_positions = batch_size_ * seq_len_; // Allocate state tensors for maximum expected size // These track the ACT state for each position (batch, sequence) halting_probs_.set_size(total_positions, 
1, 1, 1); // p_t^n cumulative_halting_.set_size(total_positions, 1, 1, 1); // h_t^n remainders_.set_size(total_positions, 1, 1, 1); // rho_t n_steps_.set_size(total_positions, 1, 1, 1); // N(t) logits_.set_size(total_positions, 1, 1, 1); // logits before sigmoid grad_logits_.set_size(total_positions, 1, 1, 1); // gradient w.r.t. logits true_effective_weights_.set_size(total_positions, 1, 1, 1); // Input cache needs full dimensions for gradient computation input_cache_.set_size(batch_size_, num_channels_, seq_len_, d_model_); } bool all_positions_halted(const resizable_tensor& ch) const { const float* cum_halt = ch.host(); const long total = batch_size_ * seq_len_; for (long i = 0; i < total; ++i) { if (cum_halt[i] < halt_threshold_) return false; } return true; } void compute_ponder_stats() { const float* steps = n_steps_.host(); const long total = batch_size_ * seq_len_; // Compute average number of steps: (1/T) * SUM N(t) float sum_steps = 0; for (long i = 0; i < total; ++i) sum_steps += steps[i]; avg_steps_ = sum_steps / total; // Normalize ponder cost by maximum possible steps ponder_cost_ = avg_steps_ / max_steps_; } // Configuration parameters long max_steps_; // Maximum computation steps per position float halt_threshold_; // theta: Halting threshold (typically 0.99) float ponder_penalty_; // lambda: Ponder cost weight for regularization bool enable_depth_scaling_; // Enable depth-dependent gradient scaling // Dimension tracking long batch_size_; long seq_len_; long d_model_; long num_channels_; long feature_dim_; // Learnable parameters resizable_tensor params; // Working memory resizable_tensor halting_probs_; // p_t^n: Halting probabilities resizable_tensor cumulative_halting_; // h_t^n: Cumulative halting probabilities resizable_tensor remainders_; // rho_t: Remaining probability mass resizable_tensor n_steps_; // N(t): Number of steps taken resizable_tensor logits_; // Raw logits before sigmoid resizable_tensor grad_logits_; // Gradients w.r.t. 
logits resizable_tensor input_cache_; // Cached input for backward pass resizable_tensor true_effective_weights_; // Statistics for monitoring float ponder_cost_; // R(x): Current ponder cost float avg_steps_; // Average number of computation steps }; template using adaptive_computation_time = add_layer, SUBNET>; template using act = add_layer, SUBNET>; // Default 8 steps template using act4 = add_layer, SUBNET>; // Fast version template using act16 = add_layer, SUBNET>; // Deep version // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_LAYERS_H_ ================================================ FILE: dlib/dnn/layers_abstract.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #undef DLIB_DNn_LAYERS_ABSTRACT_H_ #ifdef DLIB_DNn_LAYERS_ABSTRACT_H_ #include "../cuda/tensor_abstract.h" #include "core_abstract.h" #include "../cuda/operation_mode.h" namespace dlib { // ---------------------------------------------------------------------------------------- class SUBNET { /*! WHAT THIS OBJECT REPRESENTS This object represents a deep neural network. In particular, it is the simplified interface through which layer objects interact with their subnetworks. A layer's two important tasks are to (1) take outputs from its subnetwork and forward propagate them through itself and (2) to backwards propagate an error gradient through itself and onto its subnetwork. The idea of a subnetwork is illustrated in the following diagram: +---------------------------------------------------------+ | loss <-- layer1 <-- layer2 <-- ... <-- layern <-- input | +---------------------------------------------------------+ ^ ^ \__ subnetwork for layer1 __/ Therefore, by "subnetwork" we mean the part of the network closer to the input. Note that there is no dlib::SUBNET type. 
It is shown here purely to document the interface layer objects expect to see when
                they interact with a network.
        !*/

    public:
        // You aren't allowed to copy subnetworks from inside a layer.
        SUBNET(const SUBNET&) = delete;
        SUBNET& operator=(const SUBNET&) = delete;

        const tensor& get_output(
        ) const;
        /*!
            ensures
                - returns the output of this subnetwork.  This is the data that the next
                  layer in the network will take as input.
                - have_same_dimensions(#get_gradient_input(), get_output()) == true
        !*/

        tensor& get_gradient_input(
        );
        /*!
            ensures
                - returns the error gradient for this subnetwork.  That is, this is the
                  error gradient that this network will use to update itself.  Therefore,
                  when performing back propagation, layers that sit on top of this
                  subnetwork write their back propagated error gradients into
                  get_gradient_input().  Or to put it another way, during back
                  propagation, layers take the contents of their get_gradient_input() and
                  back propagate it through themselves and store the results into their
                  subnetwork's get_gradient_input().
        !*/

        const NEXT_SUBNET& subnet(
        ) const;
        /*!
            ensures
                - returns the subnetwork of *this network.  With respect to the diagram
                  above, if *this was layer1 then subnet() would return the network that
                  begins with layer2.
        !*/

        NEXT_SUBNET& subnet(
        );
        /*!
            ensures
                - returns the subnetwork of *this network.  With respect to the diagram
                  above, if *this was layer1 then subnet() would return the network that
                  begins with layer2.
        !*/

        const INPUT_LAYER& input_layer(
        ) const;
        /*!
            ensures
                - returns the very first layer in *this network.  It's equivalent to
                  calling subnet() recursively until you get to the first layer.  This
                  means it will return the object that is an implementation of the
                  EXAMPLE_INPUT_LAYER interface defined in input_abstract.h
        !*/

        INPUT_LAYER& input_layer(
        );
        /*!
            ensures
                - returns the very first layer in *this network.  It's equivalent to
                  calling subnet() recursively until you get to the first layer.  This
                  means it will return the object that is an implementation of the
                  EXAMPLE_INPUT_LAYER interface defined in input_abstract.h
        !*/

        const layer_details_type& layer_details(
        ) const;
        /*!
            ensures
                - returns the layer_details_type instance that defines the behavior of the
                  layer at the top of this network.  I.e. returns the layer details that
                  define the behavior of the layer nearest to the network output rather
                  than the input layer.  For computational layers, this is the object
                  implementing the EXAMPLE_COMPUTATIONAL_LAYER_ interface that defines the
                  layer's behavior.
        !*/

        unsigned int sample_expansion_factor (
        ) const;
        /*!
            ensures
                - When to_tensor() is invoked on this network's input layer it converts N
                  input objects into M samples, all stored inside a resizable_tensor.  It
                  is always the case that M is some integer multiple of N.
                  sample_expansion_factor() returns the value of this multiplier.  To be
                  very specific, it is always true that M==I*N where I is some integer.
                  This integer I is what is returned by sample_expansion_factor().

                  It should be noted that computational layers likely do not care about
                  the sample expansion factor.  It is only really of concern inside a loss
                  layer where you need to know its value so that tensor samples can be
                  matched against truth objects.  Moreover, in most cases the sample
                  expansion factor is 1.
        !*/

    };

// ----------------------------------------------------------------------------------------

    class EXAMPLE_COMPUTATIONAL_LAYER_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                Each computational layer in a deep neural network can be thought of as a
                function, f(data,parameters), that takes in a data tensor, some
                parameters, and produces an output tensor.  You create an entire deep
                network by composing these functions.  Importantly, you are able to use a
                wide range of different functions to accommodate the task you are trying
                to accomplish.  Therefore, dlib includes a number of common layer types
                but if you want to define your own then you simply implement a class with
                the same interface as EXAMPLE_COMPUTATIONAL_LAYER_.

                Note that there is no dlib::EXAMPLE_COMPUTATIONAL_LAYER_ type.  It is shown
                here purely to document the interface that a layer object must implement.

                The central work of defining a layer is implementing the forward and
                backward methods.  When you do this you have four options:
                    - Implement the forward() and backward() methods according to the
                      specification shown below.  Do not implement forward_inplace() and
                      backward_inplace().
                    - Implement the forward() and backward() methods according to the
                      specification shown below, except exclude the computed_output
                      parameter from backward().  Doing this will allow dlib to make some
                      layers execute in-place and therefore run a little faster and use
                      less memory.  Do not implement forward_inplace() and
                      backward_inplace().
                    - Implement the forward_inplace() and backward_inplace() methods
                      according to the specification shown below.  Do not implement
                      forward() and backward().  These in-place methods allow some types of
                      layers to be implemented more efficiently.
                    - Implement the forward_inplace() and backward_inplace() methods
                      according to the specification shown below, except exclude the
                      computed_output parameter from backward_inplace().  Doing this will
                      allow dlib to make some layers execute in-place and therefore run a
                      little faster and use less memory.  Do not implement forward() and
                      backward().

                It should also be noted that layers may define additional layer specific
                fields and the solvers can use these fields as they see fit.  For example,
                some layers define get_learning_rate_multiplier() and
                get_weight_decay_multiplier() methods.  The solvers that come with dlib
                look at these methods, if they exist, and adjust the learning rate or
                weight decay for that layer according to the multiplier.  Therefore, you
                can add these methods to your layer types if you want, or even define new
                fields and new solvers that use those fields in some way.
        !*/

    public:

        EXAMPLE_COMPUTATIONAL_LAYER_(
        );
        /*!
            ensures
                - Default constructs this object.  This function is not required to do
                  anything in particular but it must exist, that is, it is required that
                  layer objects be default constructible.
        !*/

        EXAMPLE_COMPUTATIONAL_LAYER_ (
            const EXAMPLE_COMPUTATIONAL_LAYER_& item
        );
        /*!
            ensures
                - EXAMPLE_COMPUTATIONAL_LAYER_ objects are copy constructible
        !*/

        EXAMPLE_COMPUTATIONAL_LAYER_(
            const some_other_layer_type& item
        );
        /*!
            ensures
                - Constructs this object from item.  This form of constructor is optional
                  but it allows you to provide a conversion from one layer type to another.
                  For example, the following code is valid only if my_layer2 can be
                  constructed from my_layer1:
                    relu<fc<my_layer1<fc<input<matrix<float>>>>>> my_dnn1;
                    relu<fc<my_layer2<fc<input<matrix<float>>>>>> my_dnn2(my_dnn1);
                  This kind of pattern is useful if you want to use one type of layer
                  during training but a different type of layer during testing since it
                  allows you to easily convert between related deep neural network types.

                  Additionally, if you provide a constructor to build a layer from another
                  layer type you should also write your layer's deserialize() routine such
                  that it can read that other layer's serialized data in addition to your
                  own serialized data.
        !*/

        template void setup (
            const SUBNET& sub
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
            ensures
                - performs any necessary initial memory allocations and/or sets parameters
                  to their initial values prior to learning.  Therefore, calling setup
                  destroys any previously learned parameters.  Also, typically setup()
                  would look at the dimensions of the outputs of sub and configure the
                  number of parameters in *this accordingly.
        !*/

        template void forward(
            const SUBNET& sub,
            resizable_tensor& data_output
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
                - setup() has been called.
            ensures
                - Runs the output of the subnetwork through this layer and stores the
                  results into #data_output.  In particular, forward() can use any of the
                  outputs in sub (e.g. sub.get_output(), sub.subnet().get_output(), etc.)
                  to compute whatever it wants.
        !*/

        template void backward(
            const tensor& computed_output, // this parameter is optional
            const tensor& gradient_input,
            SUBNET& sub,
            tensor& params_grad
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
                - setup() has been called.
                - computed_output is the tensor resulting from calling
                  forward(sub,computed_output).  Moreover, this was the most recent call
                  to forward().  This means that forward() is allowed to cache
                  intermediate results so they can be used during the backward
                  computation.
                - have_same_dimensions(gradient_input, computed_output) == true
                - have_same_dimensions(sub.get_gradient_input(), sub.get_output()) == true
                - have_same_dimensions(params_grad, get_layer_params()) == true
            ensures
                - This function outputs the gradients of this layer with respect to the
                  input data from sub and also with respect to this layer's parameters.
                  These gradients are stored into #sub and #params_grad, respectively. To
                  be precise, the gradients are taken of a function f(sub,get_layer_params())
                  which is defined thusly:
                    - Recalling that computed_output is a function of both sub and
                      get_layer_params(), since it is the result of calling
                      forward(sub,computed_output):
                      let f(sub,get_layer_params()) == dot(computed_output, gradient_input)
                  Then we define the following gradient vectors:
                    - PARAMETER_GRADIENT == gradient of f(sub,get_layer_params()) with
                      respect to get_layer_params().
                    - for all valid I:
                        - DATA_GRADIENT_I == gradient of f(sub,get_layer_params()) with
                          respect to layer<I>(sub).get_output() (recall that forward() can
                          draw inputs from the immediate sub layer, sub.subnet(), or
                          any earlier layer.  So you must consider the gradients with
                          respect to all inputs drawn from sub)
                  Finally, backward() outputs these gradients by performing:
                    - params_grad = PARAMETER_GRADIENT
                    - for all valid I:
                        - layer<I>(sub).get_gradient_input() += DATA_GRADIENT_I
        !*/

        void forward_inplace(
            const tensor& data_input,
            tensor& data_output
        );
        /*!
            requires
                - have_same_dimensions(data_input,data_output) == true
                - setup() has been called.
            ensures
                - Runs the data_input tensor through this layer and stores the output into
                  #data_output.
                - This function supports in-place operation, i.e. having
                  is_same_object(data_input, data_output)==true
        !*/

        void backward_inplace(
            const tensor& computed_output, // this parameter is optional
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& params_grad
        );
        /*!
            requires
                - setup() has been called.
                - computed_output is the tensor resulting from the most recent call to
                  forward_inplace().  This means that forward_inplace() is allowed to cache
                  intermediate results so they can be used during the backward
                  computation.
                - have_same_dimensions(gradient_input, data_grad) == true
                - have_same_dimensions(gradient_input, computed_output) == true
                - have_same_dimensions(params_grad, get_layer_params()) == true
            ensures
                - This function supports in-place operation, i.e. having
                  is_same_object(gradient_input, data_grad)==true
                - This function outputs the gradients of this layer with respect to the
                  input data from a sublayer and also with respect to this layer's
                  parameters.  These gradients are stored into #data_grad and
                  #params_grad, respectively.  To be precise, the gradients are taken of
                  a function f(data_input,get_layer_params()) which is defined thusly:
                    - Recalling that computed_output is a function of both the input to
                      forward_inplace() and get_layer_params(), since it is the result of
                      calling forward_inplace(data_input,computed_output):
                      let f(data_input,get_layer_params()) == dot(computed_output, gradient_input)
                  Then we define the following gradient vectors:
                    - PARAMETER_GRADIENT == gradient of f(data_input,get_layer_params())
                      with respect to get_layer_params().
                    - DATA_GRADIENT == gradient of f(data_input,get_layer_params()) with
                      respect to data_input.
                  Finally, backward_inplace() outputs these gradients by performing:
                    - params_grad = PARAMETER_GRADIENT
                    - if (is_same_object(gradient_input, data_grad)) then
                        - data_grad = DATA_GRADIENT
                    - else
                        - data_grad += DATA_GRADIENT
        !*/

        const tensor& get_layer_params(
        ) const;
        /*!
            ensures
                - returns the parameters that define the behavior of forward().
        !*/

        tensor& get_layer_params(
        );
        /*!
            ensures
                - returns the parameters that define the behavior of forward().
        !*/

        dpoint map_input_to_output(dpoint p) const;
        dpoint map_output_to_input(dpoint p) const;
        /*!
            These two functions are optional.  If provided, they should map between
            (column,row) coordinates in input and output tensors of forward().  Providing
            these functions allows you to use global utility functions like
            input_tensor_to_output_tensor().
        !*/

        void clean (
        );
        /*!
            Implementing this function is optional.  If you don't need it then you don't
            have to provide a clean().  But if you do provide it then it must behave as
            follows:

            ensures
                - calling clean() causes this object to forget about everything except its
                  parameters.  This is useful if your layer caches information between
                  forward and backward passes and you want to clean out that cache
                  information before saving the network to disk.
        !*/

    };

    std::ostream& operator<<(std::ostream& out, const EXAMPLE_COMPUTATIONAL_LAYER_& item);
    /*!
        print a string describing this layer.
!*/

    void to_xml(const EXAMPLE_COMPUTATIONAL_LAYER_& item, std::ostream& out);
    /*!
        This function is optional, but required if you want to print your networks with
        net_to_xml().  When provided, to_xml() prints a layer as XML.
    !*/

    void serialize(const EXAMPLE_COMPUTATIONAL_LAYER_& item, std::ostream& out);
    void deserialize(EXAMPLE_COMPUTATIONAL_LAYER_& item, std::istream& in);
    /*!
        provides serialization support
    !*/

    // For each layer you define, always define an add_layer template so that layers can be
    // easily composed.  Moreover, the convention is that the layer class ends with an _
    // while the add_layer template has the same name but without the trailing _.
    template using EXAMPLE_COMPUTATIONAL_LAYER = add_layer;

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    enum fc_bias_mode
    {
        FC_HAS_BIAS = 0,
        FC_NO_BIAS = 1
    };

    struct num_fc_outputs
    {
        num_fc_outputs(unsigned long n) : num_outputs(n) {}
        unsigned long num_outputs;
    };

    template <
        unsigned long num_outputs,
        fc_bias_mode bias_mode
        >
    class fc_
    {
        /*!
            REQUIREMENTS ON num_outputs
                num_outputs > 0

            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
                defined above.  In particular, it defines a fully connected layer that
                takes an input tensor and multiplies it by a weight matrix and outputs the
                results.

                The dimensions of the tensors output by this layer are as follows (letting
                IN be the input tensor and OUT the output tensor):
                    - OUT.num_samples() == IN.num_samples()
                    - OUT.k()  == get_num_outputs()
                    - OUT.nr() == 1
                    - OUT.nc() == 1
        !*/

    public:

        fc_(
        );
        /*!
            ensures
                - #get_num_outputs() == num_outputs
                - #get_bias_mode() == bias_mode
                - #get_learning_rate_multiplier()      == 1
                - #get_weight_decay_multiplier()       == 1
                - #get_bias_learning_rate_multiplier() == 1
                - #get_bias_weight_decay_multiplier()  == 0
        !*/

        fc_(
            num_fc_outputs o
        );
        /*!
            ensures
                - #get_num_outputs() == o.num_outputs
                - #get_bias_mode() == bias_mode
                - #get_learning_rate_multiplier()      == 1
                - #get_weight_decay_multiplier()       == 1
                - #get_bias_learning_rate_multiplier() == 1
                - #get_bias_weight_decay_multiplier()  == 0
        !*/

        unsigned long get_num_outputs (
        ) const;
        /*!
            ensures
                - This layer outputs column vectors that contain get_num_outputs()
                  elements.  That is, the output tensor T from forward() will be such that:
                    - T.num_samples() == however many samples were given to forward().
                    - T.k() == get_num_outputs()
                    - The rest of the dimensions of T will be 1.
        !*/

        void set_num_outputs(
            long num
        );
        /*!
            requires
                - num > 0
                - get_layer_params().size() == 0 || get_num_outputs() == num
                  (i.e. You can't change the number of outputs in fc_ if the parameter
                  tensor has already been allocated.)
            ensures
                - #get_num_outputs() == num
        !*/

        fc_bias_mode get_bias_mode (
        ) const;
        /*!
            ensures
                - returns the bias mode which determines if this layer includes bias terms.
                  That is, if the bias mode is FC_HAS_BIAS then a different constant scalar
                  is added to each of the outputs of this layer.
        !*/

        double get_learning_rate_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number.  The interpretation is that this object is
                  requesting that the learning rate used to optimize its parameters be
                  multiplied by get_learning_rate_multiplier().
        !*/

        double get_weight_decay_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number.  The interpretation is that this object is
                  requesting that the weight decay used to optimize its parameters be
                  multiplied by get_weight_decay_multiplier().
        !*/

        void set_learning_rate_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_learning_rate_multiplier() == val
        !*/

        void set_weight_decay_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_weight_decay_multiplier() == val
        !*/

        double get_bias_learning_rate_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number.  The interpretation is that this object is
                  requesting that the learning rate used to optimize its bias parameters be
                  multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
        !*/

        double get_bias_weight_decay_multiplier(
        ) const;
        /*!
            ensures
                - returns a multiplier number.  The interpretation is that this object is
                  requesting that the weight decay used to optimize its bias parameters be
                  multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
        !*/

        void set_bias_learning_rate_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_bias_learning_rate_multiplier() == val
        !*/

        void set_bias_weight_decay_multiplier(
            double val
        );
        /*!
            requires
                - val >= 0
            ensures
                - #get_bias_weight_decay_multiplier() == val
        !*/

        void disable_bias(
        );
        /*!
            ensures
                - bias_is_disabled() returns true
        !*/

        bool bias_is_disabled(
        ) const;
        /*!
            ensures
                - returns true if bias learning is disabled for this layer.  This means the
                  biases will not be learned during the training and they will not be used
                  in the forward or backward methods either.
        !*/

        alias_tensor_const_instance get_weights(
        ) const;
        /*!
            ensures
                - returns an alias of get_layer_params(), containing the weights matrix of
                  the fully connected layer.
                - #get_weights().num_samples() is the number of elements in an input
                  sample, i.e. the sublayer output's k * nr * nc.
                - #get_weights().k() == #get_num_outputs()
                - if get_bias_mode() == FC_HAS_BIAS:
                    - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
                - else:
                    - #get_layer_params().size() == #get_weights().size()
        !*/

        alias_tensor_instance get_weights(
        );
        /*!
            ensures
                - returns an alias of get_layer_params(), containing the weights matrix of
                  the fully connected layer.
                - #get_weights().num_samples() is the number of elements in an input
                  sample, i.e. the sublayer output's k * nr * nc.
                - #get_weights().k() == #get_num_outputs()
                - if get_bias_mode() == FC_HAS_BIAS:
                    - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
                - else:
                    - #get_layer_params().size() == #get_weights().size()
        !*/

        alias_tensor_const_instance get_biases(
        ) const;
        /*!
            requires
                - get_bias_mode() == FC_HAS_BIAS
            ensures
                - returns an alias of get_layer_params(), containing the bias vector of
                  the fully connected layer.
                - #get_biases().num_samples() == 1
                - #get_biases().k() == #get_num_outputs()
                - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
        !*/

        alias_tensor_instance get_biases(
        );
        /*!
            requires
                - get_bias_mode() == FC_HAS_BIAS
            ensures
                - returns an alias of get_layer_params(), containing the bias vector of
                  the fully connected layer.
                - #get_biases().num_samples() == 1
                - #get_biases().k() == #get_num_outputs()
                - #get_layer_params().size() == (#get_weights().size() + #get_biases().size())
        !*/

        template void setup (const SUBNET& sub);
        template void forward(const SUBNET& sub, resizable_tensor& output);
        template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
        const tensor& get_layer_params() const;
        tensor& get_layer_params();
        /*!
            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
            interface.
!*/ }; template < unsigned long num_outputs, typename SUBNET > using fc = add_layer, SUBNET>; template < unsigned long num_outputs, typename SUBNET > using fc_no_bias = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- enum linear_bias_mode { LINEAR_HAS_BIAS, LINEAR_NO_BIAS }; template < unsigned long num_outputs, linear_bias_mode bias_mode = LINEAR_HAS_BIAS > class linear_ { /*! REQUIREMENTS ON num_outputs num_outputs > 0 WHAT THIS OBJECT REPRESENTS This is an implementation of a linear layer, which applies a linear transformation to the input data. For a layer with bias, the transformation is: output = input * weights + bias For a layer without bias, it's simply: output = input * weights The input tensor can have any number of sample, k (channel), and nr (row) dimensions, but the nc (column) dimension must match the number of input features. The output tensor will have the same dimensions as the input tensor, except for the nc dimension which will be equal to num_outputs. This layer is similar to the fc_ layer, but optimized for the case where the input and output tensors maintain the same dimensions, excluding the feature dimension (nc). This makes it useful for working with multi-dimensional data. !*/ public: linear_( ); /*! ensures - #get_num_outputs() == num_outputs - #get_bias_mode() == bias_mode - #get_learning_rate_multiplier() == 1 !*/ double get_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier that will be applied to the gradient of this layer during training. This value appears as a multiplicative factor in the update rule. So if get_learning_rate_multiplier() == 1 then the learning rate will be multiplied by 1 and thus not modified. 
However, if get_learning_rate_multiplier() == 0.1 then the learning rate will be multiplied by 0.1, making the layer update 10 times slower than it would otherwise be. !*/ void set_learning_rate_multiplier( double val ); /*! ensures - #get_learning_rate_multiplier() == val !*/ unsigned long get_num_inputs( ) const; /*! ensures - Returns the number of input features this layer expects. - For an uninitialized layer (i.e., one that has not seen any data during setup or forward pass), this will be zero. !*/ unsigned long get_num_outputs( ) const; /*! ensures - Returns the number of output features this layer produces. I.e., this value is num_outputs. !*/ void set_num_outputs( long num ); /*! requires - num > 0 ensures - #get_num_outputs() == num throws - std::runtime_error if this function is called after the layer parameters have been allocated and the new number of outputs doesn't match the previously set number of outputs. !*/ linear_bias_mode get_bias_mode( ) const; /*! ensures - Returns a value indicating whether this layer has a bias term. I.e. returns bias_mode. !*/ template void setup( const SUBNET& sub ); /*! ensures - Performs the necessary setup work to process data through this layer. - Sets the input size based on the dimensions of the input tensor from sub. - Allocates the parameter tensor and initializes its values. - #get_num_inputs() == the number of columns in sub.get_output() (i.e., nc). !*/ template void forward( const SUBNET& sub, resizable_tensor& output ); /*! requires - setup() has been called - sub.get_output().nc() == get_num_inputs() ensures - Applies the linear transformation to the input tensor from sub and stores the results in output. - #output.num_samples() == sub.get_output().num_samples() - #output.k() == sub.get_output().k() - #output.nr() == sub.get_output().nr() - #output.nc() == get_num_outputs() !*/ template void backward( const tensor& gradient_input, SUBNET& sub, tensor& params_grad ); /*! 
requires - setup() has been called - sub.get_output().nc() == get_num_inputs() - gradient_input has the same dimensions as the output of forward() ensures - Computes the gradients of this layer with respect to the parameters and the input tensor, and updates the corresponding gradient tensors. - Updates params_grad based on the gradients of the weights and biases (if present). - Updates sub's gradient_input based on the gradients of the inputs to this layer. !*/ alias_tensor_instance get_weights( ); /*! requires - setup() has been called ensures - Returns a reference to the weights matrix of this layer. !*/ alias_tensor_const_instance get_weights( ) const; /*! requires - setup() has been called ensures - Returns a const reference to the weights matrix of this layer. !*/ alias_tensor_instance get_biases( ); /*! requires - bias_mode == LINEAR_HAS_BIAS - setup() has been called ensures - Returns a reference to the bias vector of this layer. throws - static_assert failure if bias_mode != LINEAR_HAS_BIAS !*/ alias_tensor_const_instance get_biases( ) const; /*! requires - bias_mode == LINEAR_HAS_BIAS - setup() has been called ensures - Returns a const reference to the bias vector of this layer. throws - static_assert failure if bias_mode != LINEAR_HAS_BIAS !*/ dpoint map_input_to_output( const dpoint& p ) const; /*! ensures - Returns p, since the linear layer maintains the same spatial dimensions. !*/ dpoint map_output_to_input( const dpoint& p ) const; /*! ensures - Returns p, since the linear layer maintains the same spatial dimensions. !*/ const tensor& get_layer_params( ) const; /*! ensures - Returns the parameters that define this layer, i.e., the weights and biases (if present) that are updated during training. !*/ tensor& get_layer_params( ); /*! ensures - Returns the parameters that define this layer, i.e., the weights and biases (if present) that are updated during training. 
!*/ friend void serialize(const linear_& item, std::ostream& out); friend void deserialize(linear_& item, std::istream& in); /*! provides serialization support !*/ }; template < unsigned long num_outputs, typename SUBNET > using linear = add_layer, SUBNET>; /*! This is a layer that applies a linear transformation with bias to the input: output = input * weights + bias !*/ template < unsigned long num_outputs, typename SUBNET > using linear_no_bias = add_layer, SUBNET>; /*! This is a layer that applies a linear transformation without bias to the input: output = input * weights !*/ // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- struct num_con_outputs { num_con_outputs(unsigned long n) : num_outputs(n) {} unsigned long num_outputs; }; template < long _num_filters, long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class con_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - _num_filters > 0 - _nr >= 0 - _nc >= 0 - _stride_y > 0 - _stride_x > 0 - _padding_y >= 0 - _padding_x >= 0 - Also, we require that: - if (_nr == 0) then - _padding_y == 0 - else - _padding_y < _nr - if (_nc == 0) then - _padding_x == 0 - else - _padding_x < _nc WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a convolution layer that takes an input tensor (nominally representing an image) and convolves it with a set of filters and then outputs the results. 
The dimensions of the tensors output by this layer are as follows (letting IN be the input tensor and OUT the output tensor): - OUT.num_samples() == IN.num_samples() - OUT.k() == num_filters() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x() Note also that setting _nr or _nc to 0 has a special meaning of "set the filter size equal to the input image size". Specifically, it means: - if (_nr == 0) then - nr() == IN.nr() - OUT.nr() == 1 - if (_nc == 0) then - nc() == IN.nc() - OUT.nc() == 1 !*/ public: con_( ); /*! ensures - #num_filters() == _num_filters - #nr() == _nr - #nc() == _nc - #stride_y() == _stride_y - #stride_x() == _stride_x - #padding_y() == _padding_y - #padding_x() == _padding_x - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 1 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 0 !*/ con_( num_con_outputs o ); /*! ensures - #num_filters() == o.num_outputs - #nr() == _nr - #nc() == _nc - #stride_y() == _stride_y - #stride_x() == _stride_x - #padding_y() == _padding_y - #padding_x() == _padding_x - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 1 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 0 !*/ long num_filters( ) const; /*! ensures - returns the number of filters contained in this layer. The k dimension of the output tensors produced by this layer will be equal to the number of filters. !*/ void set_num_filters( long num ); /*! requires - num > 0 - get_layer_params().size() == 0 || num_filters() == num (i.e. You can't change the number of filters in con_ if the parameter tensor has already been allocated.) ensures - #num_filters() == num !*/ long nr( ) const; /*! ensures - returns the number of rows in the filters in this layer. 
Note that if nr()==0 then it means the size of the filter is not yet assigned, but once setup() is called nr() will be set to the input tensor's nr(). Therefore, nr()==0 has the special interpretation of "be the same size as the input tensor". !*/ long nc( ) const; /*! ensures - returns the number of columns in the filters in this layer. Note that if nc()==0 then it means the size of the filter is not yet assigned, but once setup() is called nc() will be set to the input tensor's nc(). Therefore, nc()==0 has the special interpretation of "be the same size as the input tensor". !*/ long stride_y( ) const; /*! ensures - returns the vertical stride used when convolving the filters over an image. That is, each filter will be moved stride_y() pixels down at a time when it moves over the image. !*/ long stride_x( ) const; /*! ensures - returns the horizontal stride used when convolving the filters over an image. That is, each filter will be moved stride_x() pixels right at a time when it moves over the image. !*/ long padding_y( ) const; /*! ensures - returns the number of pixels of zero padding added to the top and bottom sides of the image. !*/ long padding_x( ) const; /*! ensures - returns the number of pixels of zero padding added to the left and right sides of the image. !*/ double get_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its parameters be multiplied by get_learning_rate_multiplier(). !*/ double get_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its parameters be multiplied by get_weight_decay_multiplier(). !*/ void set_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_learning_rate_multiplier() == val !*/ void set_weight_decay_multiplier( double val ); /*! 
requires - val >= 0 ensures - #get_weight_decay_multiplier() == val !*/ double get_bias_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its bias parameters be multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). !*/ double get_bias_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its bias parameters be multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). !*/ void set_bias_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_learning_rate_multiplier() == val !*/ void set_bias_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_weight_decay_multiplier() == val !*/ void disable_relu( ); /*! ensures - relu_is_disabled() returns true !*/ void enable_relu( ); /*! ensures - relu_is_disabled() returns false !*/ bool relu_is_disabled( ) const; /*! ensures - returns true if relu is disabled for this layer. This means no activation function will be applied after the convolution when calling forward. !*/ void disable_bias( ); /*! ensures - bias_is_disabled() returns true - if bias was enabled and allocated, it resizes the layer parameters to accommodate only the filter parameters, freeing the bias parameters. !*/ void enable_bias( ); /*! ensures - bias_is_disabled() returns false - if bias was disabled and not allocated, it resizes the layer parameters to accommodate the new zero-initialized biases. !*/ bool bias_is_disabled( ) const; /*! ensures - returns true if bias learning is disabled for this layer. This means the biases will not be learned during the training and they will not be used in the forward or backward methods either. 
!*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < long num_filters, long nr, long nc, int stride_y, int stride_x, typename SUBNET > using con = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _num_filters, long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class cont_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - _num_filters > 0 - _nr > 0 - _nc > 0 - _stride_y > 0 - _stride_x > 0 - Also, we require that: - 0 <= _padding_y && _padding_y < _nr - 0 <= _padding_x && _padding_x < _nc WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a transposed convolution layer that takes an input tensor and transpose convolves (sometimes called "deconvolution") it with a set of filters and then outputs the results. This is essentially a convolutional layer that allows fractional strides. Therefore, you can make output tensors that are larger than the input tensors using this layer type. The dimensions of the tensors output by this layer are as follows (letting IN be the input tensor and OUT the output tensor): - OUT.num_samples() == IN.num_samples() - OUT.k() == num_filters() - OUT.nr() == stride_y()*(IN.nr()-1) + nr() - 2*padding_y() - OUT.nc() == stride_x()*(IN.nc()-1) + nc() - 2*padding_x() !*/ public: cont_( ); /*! 
ensures - #num_filters() == _num_filters - #nr() == _nr - #nc() == _nc - #stride_y() == _stride_y - #stride_x() == _stride_x - #padding_y() == _padding_y - #padding_x() == _padding_x - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 1 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 0 !*/ cont_( num_con_outputs o ); /*! ensures - #num_filters() == o.num_outputs - #nr() == _nr - #nc() == _nc - #stride_y() == _stride_y - #stride_x() == _stride_x - #padding_y() == _padding_y - #padding_x() == _padding_x - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 1 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 0 !*/ long num_filters( ) const; /*! ensures - returns the number of filters contained in this layer. The k dimension of the output tensors produced by this layer will be equal to the number of filters. !*/ void set_num_filters( long num ); /*! requires - num > 0 - get_layer_params().size() == 0 || num_filters() == num (i.e. You can't change the number of filters in cont_ if the parameter tensor has already been allocated.) ensures - #num_filters() == num !*/ long nr( ) const; /*! ensures - returns the number of rows in the filters in this layer. !*/ long nc( ) const; /*! ensures - returns the number of columns in the filters in this layer. !*/ long stride_y( ) const; /*! ensures - returns the vertical stride used when convolving the filters over an image. That is, each filter will be moved 1.0/stride_y() pixels down at a time when it moves over the image. !*/ long stride_x( ) const; /*! ensures - returns the horizontal stride used when convolving the filters over an image. That is, each filter will be moved 1.0/stride_x() pixels right at a time when it moves over the image. !*/ long padding_y( ) const; /*! ensures - returns the number of pixels of zero padding added to the top and bottom sides of the image. !*/ long padding_x( ) const; /*! 
ensures - returns the number of pixels of zero padding added to the left and right sides of the image. !*/ double get_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its parameters be multiplied by get_learning_rate_multiplier(). !*/ double get_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its parameters be multiplied by get_weight_decay_multiplier(). !*/ void set_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_learning_rate_multiplier() == val !*/ void set_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_weight_decay_multiplier() == val !*/ double get_bias_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its bias parameters be multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). !*/ double get_bias_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its bias parameters be multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). !*/ void set_bias_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_learning_rate_multiplier() == val !*/ void set_bias_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_weight_decay_multiplier() == val !*/ void disable_bias( ); /*! ensures - bias_is_disabled() returns true !*/ bool bias_is_disabled( ) const; /*! ensures - returns true if bias learning is disabled for this layer. 
This means the biases will not be learned during the training and they will not be used in the forward or backward methods either. !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < long num_filters, long nr, long nc, int stride_y, int stride_x, typename SUBNET > using cont = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < int scale_y, int scale_x > class upsample_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS All of them must be >= 1. WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it allows you to upsample a layer using bilinear interpolation. To be very specific, it upsamples each of the channels in an input tensor. Therefore, if IN is the input tensor to this layer and OUT the output tensor, then we will have: - OUT.num_samples() == IN.num_samples() - OUT.k() == IN.k() - OUT.nr() == IN.nr()*scale_y - OUT.nc() == IN.nc()*scale_x - for all valid i,k: image_plane(OUT,i,k) is a copy of image_plane(IN,i,k) that has been bilinearly interpolated to fit into the shape of image_plane(OUT,i,k). !*/ public: upsample_( ); /*! ensures - This object has no state, so the constructor does nothing, aside from providing default constructability. 
!*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < int scale, typename SUBNET > using upsample = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long NR_, long NC_ > class resize_to_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - NR_ >= 1 - NC_ >= 1 WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it allows you to resize a layer using bilinear interpolation. To be very specific, it resizes each of the channels in an input tensor. Therefore, if IN is the input tensor to this layer and OUT the output tensor, then we will have: - OUT.num_samples() == IN.num_samples() - OUT.k() == IN.k() - OUT.nr() == NR_ - OUT.nc() == NC_ - for all valid i,k: image_plane(OUT,i,k) is a copy of image_plane(IN,i,k) that has been bilinearly interpolated to fit into the shape of image_plane(OUT,i,k). !*/ public: resize_to_( ); /*! ensures - This object has no state, so the constructor does nothing, aside from providing default constructability. !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. 
!*/ }; template < long NR, long NC, typename SUBNET > using resize_to = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template class reshape_to_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - k_, nr_, and nc_ must be either -1 or greater than 0. WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. It defines a layer that reshapes or resizes an input tensor into a different shape. The layer operates in two modes: 1. Pure Reshape Mode: When the total number of elements in the input tensor equals the total number of elements in the output tensor, this layer performs a simple reshaping operation without changing the values. 2. Spatial Rescaling Mode: When the channel dimension (k) remains constant but the total number of elements changes, this layer performs bilinear interpolation to resize the spatial dimensions while preserving the channel information. The dimensions of the output tensor are determined by the template parameters: - If k_ is -1, the output tensor will have the same number of channels as the input. - If nr_ is -1, the output tensor will have the same number of rows as the input. - If nc_ is -1, the output tensor will have the same number of columns as the input. Setting a value of -1 for any dimension means "keep the original dimension from the input." Note that this layer will throw an exception if you attempt to change both the channel count (k) and the total number of elements. Either: - Keep the total number of elements the same (Pure Reshape Mode), or - Keep the channel count the same and only change spatial dimensions (Spatial Rescaling Mode) !*/ public: explicit reshape_to_(); /*! ensures - #get_output_k() == k_ - #get_output_nr() == nr_ - #get_output_nc() == nc_ !*/ long get_output_k() const; /*! ensures - Returns the number of channels in the output tensor. 
If this value is -1, then the output will have the same number of channels as the input. !*/ long get_output_nr() const; /*! ensures - Returns the number of rows in the output tensor. If this value is -1, then the output will have the same number of rows as the input. !*/ long get_output_nc() const; /*! ensures - Returns the number of columns in the output tensor. If this value is -1, then the output will have the same number of columns as the input. !*/ void set_output_k(long k); /*! requires - k == -1 || k > 0 ensures - #get_output_k() == k !*/ void set_output_nr(long nr); /*! requires - nr == -1 || nr > 0 ensures - #get_output_nr() == nr !*/ void set_output_nc(long nc); /*! requires - nc == -1 || nc > 0 ensures - #get_output_nc() == nc !*/ template void setup(const SUBNET& sub); /*! requires - SUBNET implements the SUBNET interface defined at the top of this file. ensures - Configures this layer to operate on the output of sub. - If the total number of elements in the input tensor doesn't match the total number of elements in the output tensor and the channel dimension is different, an exception will be thrown. !*/ template void forward(const SUBNET& sub, resizable_tensor& output); /*! requires - SUBNET implements the SUBNET interface defined at the top of this file. - setup() has been called. ensures - Reshapes or resizes the output of sub and stores it in #output. - If the input and output tensors contain the same total number of elements (Pure Reshape Mode), then performs a pure reshape operation that does not change the values. - Otherwise (Spatial Rescaling Mode, where the channel count is unchanged), performs bilinear interpolation to resize the spatial dimensions while preserving the channel information. 
- #output.num_samples() == sub.get_output().num_samples() - #output.k() == get_output_k() if get_output_k() != -1, otherwise sub.get_output().k() - #output.nr() == get_output_nr() if get_output_nr() != -1, otherwise sub.get_output().nr() - #output.nc() == get_output_nc() if get_output_nc() != -1, otherwise sub.get_output().nc() !*/ template void backward( const tensor& gradient_input, SUBNET& sub, tensor& params_grad ); /*! requires - SUBNET implements the SUBNET interface defined at the top of this file. - setup() has been called. - gradient_input has the same dimensions as the output of forward(). ensures - Computes the gradients of this layer with respect to the input tensor and parameters, and stores them in sub.get_gradient_input() and params_grad, respectively. - This function supports both pure reshaping and spatial rescaling operations. !*/ dpoint map_input_to_output(dpoint p) const; /*! ensures - Maps a point in the input tensor's coordinate system to the corresponding point in the output tensor. This is useful for tracking how spatial locations change through the network, especially during spatial rescaling. !*/ dpoint map_output_to_input(dpoint p) const; /*! ensures - Maps a point in the output tensor's coordinate system to the corresponding point in the input tensor. This is the inverse of map_input_to_output(). !*/ const tensor& get_layer_params() const; /*! ensures - Returns the layer's parameters. This layer has no parameters, so this always returns an empty tensor. !*/ tensor& get_layer_params(); /*! ensures - Returns the layer's parameters. This layer has no parameters, so this always returns an empty tensor. !*/ }; template using reshape_to = add_layer, SUBNET>; template using flatten = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class dropout_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. 
In particular, it defines a dropout layer. Therefore, it passes its inputs through the stochastic function f(x) which outputs either 0 or x. The probability of 0 being output is given by the drop_rate argument to this object's constructor. Note that, after you finish training a network with dropout, it is a good idea to replace each dropout_ layer with a multiply_ layer because the multiply_ layer is faster and deterministic. !*/ public: explicit dropout_( float drop_rate = 0.5 ); /*! requires - 0 <= drop_rate <= 1 ensures - #get_drop_rate() == drop_rate !*/ float get_drop_rate ( ) const; /*! ensures - returns the probability that an individual input value to this layer will be replaced with 0. !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template using dropout = add_layer; // ---------------------------------------------------------------------------------------- template class dropout_rate_ : public dropout_ { /*! WHAT THIS OBJECT REPRESENTS This object represents a customizable dropout layer that inherits from the dropout_ class. It allows specifying the dropout rate at compile-time, which is particularly useful for deep networks with many layers where it might be cumbersome to explicitly modify the dropout rate for each layer individually. The main advantage of this layer is that it offers the possibility to specify the dropout rate at the moment of network construction, providing more flexibility and clarity in the network architecture definition. TEMPLATE PARAMETERS - DROP_RATE_PERCENT: An int value between 0 and 100 that specifies the dropout rate as a percentage (e.g. 
10 means a drop rate of 0.1).
This value is set at compile-time and cannot be changed during runtime. !*/ public: explicit dropout_rate_(); /*! ensures - Constructs a dropout layer whose dropout rate is DROP_RATE_PERCENT percent, i.e. it calls the base class constructor dropout_(DROP_RATE_PERCENT/100.0). - #get_drop_rate() == DROP_RATE_PERCENT/100.0 !*/ }; template using dropout_rate = add_layer, SUBNET>; template using dropout_10 = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class multiply_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a basic layer that just multiplies its input tensor with a constant value and returns the result. It therefore has no learnable parameters. !*/ public: explicit multiply_( float val = 0.5 ); /*! ensures - #get_multiply_value() == val !*/ multiply_ ( const dropout_& item ); /*! ensures - #get_multiply_value() == 1-item.get_drop_rate() (i.e. We construct the multiply_ layer so that it is essentially a deterministic version of the given dropout_ layer) !*/ float get_multiply_value ( ) const; /*! ensures - this layer simply multiplies its input tensor by get_multiply_value() and produces the result as output. !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template using multiply = add_layer; // ---------------------------------------------------------------------------------------- const double DEFAULT_LAYER_NORM_EPS = 1e-5; class layer_norm_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. 
In particular, it defines a layer normalization layer that implements the method described in the paper: Layer Normalization by Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton In particular, this layer produces output tensors with the same dimensionality as the input tensors, except that the mean and variances of the elements in each sample have been standardized to 0 and 1 respectively. This is different from batch normalization, since the normalization statistics are computed from each sample in the batch independently rather than across the whole batch. As a result, this layer is batch-size independent. !*/ public: layer_norm_( ); /*! ensures - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 0 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 1 - #get_eps() == DEFAULT_LAYER_NORM_EPS !*/ explicit layer_norm_( double eps_ ); /*! requires - eps_ > 0 ensures - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 0 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 1 - #get_eps() == eps_ (Note: eps_ has no default argument so that layer_norm_() unambiguously selects the default constructor above.) !*/ double get_eps( ) const; /*! ensures - When doing layer normalization, we are dividing by the standard deviation. This epsilon value returned by this function is added to the variance to prevent the division from dividing by zero. !*/ double get_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its parameters be multiplied by get_learning_rate_multiplier(). !*/ double get_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its parameters be multiplied by get_weight_decay_multiplier(). !*/ void set_learning_rate_multiplier( double val ); /*! 
requires - val >= 0 ensures - #get_learning_rate_multiplier() == val !*/ void set_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_weight_decay_multiplier() == val !*/ double get_bias_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its bias parameters be multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). !*/ double get_bias_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its bias parameters be multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). !*/ void set_bias_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_learning_rate_multiplier() == val !*/ void set_bias_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_weight_decay_multiplier() == val !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; // ---------------------------------------------------------------------------------------- const float DEFAULT_RMS_NORM_EPS = 1e-5f; class rms_norm_ { /*! WHAT THIS OBJECT REPRESENTS This object implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above, specifically defining a root mean square (RMS) normalization layer. RMS normalization is a technique that normalizes the input tensor based on the root mean square (RMS) of its elements. 
Unlike traditional layer normalization, which both centers and scales the data, RMS normalization only scales by the RMS value. This makes it computationally more efficient, as it avoids the need to compute the mean and subtract it from each element. This layer produces output tensors with the same dimensionality as the input tensors. Specifically, for an input tensor with shape [num_samples, k, nr, nc], the RMS normalization is applied across the [nr, nc] dimensions independently for each element in the [k] dimension and for each sample in the [num_samples] dimension. The scaling factor (RMS) and the learnable scaling parameter (gamma) are both of size [k]. The key characteristics of this layer are: - The RMS of the elements in each sample is standardized to 1. - It does not center the data (i.e., it does not subtract the mean). - A learnable scaling factor (gamma) is applied after normalization, allowing the model to adapt the scaling dynamically. This layer is particularly effective in various natural language processing tasks, where it has been shown to provide performance similar to or better than traditional layer normalization, with reduced computational overhead. !*/ public: rms_norm_( ); /*! ensures - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 0 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 1 - #get_eps() == DEFAULT_RMS_NORM_EPS !*/ explicit rms_norm_( float eps_ ); /*! requires - eps_ > 0 ensures - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 0 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 1 - #get_eps() == eps_ (Note: eps_ has no default argument so that rms_norm_() unambiguously selects the default constructor above.) !*/ float get_eps( ) const; /*! ensures - When doing RMS normalization, we are dividing by the root mean square. This epsilon value returned by this function is added to the mean square to prevent division by zero. !*/ void set_eps( float val ); /*! 
requires - val > 0 ensures - #get_eps() == val !*/ double get_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its parameters be multiplied by get_learning_rate_multiplier(). !*/ double get_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its parameters be multiplied by get_weight_decay_multiplier(). !*/ void set_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_learning_rate_multiplier() == val !*/ void set_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_weight_decay_multiplier() == val !*/ double get_bias_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its bias parameters be multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). !*/ double get_bias_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its bias parameters be multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). !*/ void set_bias_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_learning_rate_multiplier() == val !*/ void set_bias_weight_decay_multiplier( double val ); /*! 
requires - val >= 0 ensures - #get_bias_weight_decay_multiplier() == val !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template using rms_norm = add_layer; // ---------------------------------------------------------------------------------------- enum layer_mode { CONV_MODE = 0, // convolutional mode FC_MODE = 1 // fully connected mode }; const double DEFAULT_BATCH_NORM_EPS = 0.0001; template < layer_mode mode > class bn_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a batch normalization layer that implements the method described in the paper: Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift by Sergey Ioffe and Christian Szegedy In particular, this layer produces output tensors with the same dimensionality as the input tensors, except that the mean and variances of the elements have been standardized to 0 and 1 respectively. It should also be noted that when tensors with a num_samples() dimension of 1 are passed to this layer it doesn't perform batch normalization. Instead, it runs in "inference mode" where the learned linear normalizing transformation is used to transform the tensor. Finally, after you finish training a batch normalized network, it is a good idea to replace each bn_ layer with an affine_ layer because the affine_ layer is faster and will never surprise you by performing batch normalization on tensors that have a num_samples() dimension > 1.
This allows you to run large mini-batches of samples through your final network without batch normalization executing at all. !*/ public: bn_( ); /*! ensures - #get_mode() == mode - #get_running_stats_window_size() == 100 - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 0 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 1 - #get_eps() == tt::DEFAULT_BATCH_NORM_EPS !*/ explicit bn_( unsigned long window_size, double eps = tt::DEFAULT_BATCH_NORM_EPS ); /*! requires - eps > 0 - window_size > 0 ensures - #get_mode() == mode - #get_running_stats_window_size() == window_size - #get_learning_rate_multiplier() == 1 - #get_weight_decay_multiplier() == 0 - #get_bias_learning_rate_multiplier() == 1 - #get_bias_weight_decay_multiplier() == 1 - #get_eps() == eps !*/ layer_mode get_mode( ) const; /*! ensures - returns the mode of this layer, either CONV_MODE or FC_MODE. If the mode is FC_MODE then the normalization is applied across the samples in a tensor (i.e. k()*nr()*nc() different things will be normalized). Otherwise, normalization is applied across everything except for the k() dimension, resulting in there being only k() normalization equations that are applied spatially over the tensor. Therefore, if you are putting batch normalization after a fully connected layer you should use FC_MODE. Otherwise, if you are putting batch normalization after a convolutional layer you should use CONV_MODE. !*/ double get_eps( ) const; /*! ensures - When doing batch normalization, we are dividing by the standard deviation. This epsilon value returned by this function is added to the variance to prevent the division from dividing by zero. !*/ unsigned long get_running_stats_window_size ( ) const; /*! ensures - Just as recommended in the batch normalization paper, this object keeps a running average of the mean and standard deviations of the features. 
These averages are used during "inference mode" so you can run a single object through a batch normalized network. They are also what is used to initialize an affine_ layer that is constructed from a bn_ layer. This function returns the effective number of recent samples used to compute the running average. !*/ void set_running_stats_window_size ( unsigned long new_window_size ); /*! requires - new_window_size > 0 ensures - #get_running_stats_window_size() == new_window_size !*/ double get_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its parameters be multiplied by get_learning_rate_multiplier(). !*/ double get_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its parameters be multiplied by get_weight_decay_multiplier(). !*/ void set_learning_rate_multiplier( double val ); /*! requires - val >= 0 ensures - #get_learning_rate_multiplier() == val !*/ void set_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_weight_decay_multiplier() == val !*/ double get_bias_learning_rate_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the learning rate used to optimize its bias parameters be multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier(). !*/ double get_bias_weight_decay_multiplier( ) const; /*! ensures - returns a multiplier number. The interpretation is that this object is requesting that the weight decay used to optimize its bias parameters be multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier(). !*/ void set_bias_learning_rate_multiplier( double val ); /*! 
requires - val >= 0 ensures - #get_bias_learning_rate_multiplier() == val !*/ void set_bias_weight_decay_multiplier( double val ); /*! requires - val >= 0 ensures - #get_bias_weight_decay_multiplier() == val !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template using bn_con = add_layer, SUBNET>; template using bn_fc = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class affine_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it applies a simple pointwise linear transformation to an input tensor. You can think of it as having two parameter tensors, gamma and beta. If the input tensor is called INPUT then the output of this layer is: gamma*INPUT+beta where all operations are performed element wise and each sample in the INPUT tensor is processed separately. Moreover, this object has two modes that affect the dimensionalities of gamma and beta and how they are applied to compute gamma*INPUT+beta. If get_mode()==FC_MODE then gamma and beta each have the same dimensionality as the input tensor, except their num_samples() dimensions are 1. If get_mode()==CONV_MODE then gamma and beta have all their dimensions set to 1 except for k(), which is equal to INPUT.k(). 
In either case, the computation of gamma*INPUT+beta is performed pointwise over all the elements of INPUT using either: OUTPUT(n,k,r,c) == gamma(1,k,r,c)*INPUT(n,k,r,c)+beta(1,k,r,c) or OUTPUT(n,k,r,c) == gamma(1,k,1,1)*INPUT(n,k,r,c)+beta(1,k,1,1) as appropriate. Finally, note that the parameters of this layer are not learnable and therefore not modified during network updates. Instead, the layer will perform the identity transformation unless it is initialized with a bn_ layer, in which case it will perform whatever transformation the bn_ layer has learned. !*/ public: affine_( ); /*! ensures - #get_mode() == FC_MODE !*/ affine_( layer_mode mode ); /*! ensures - #get_mode() == mode !*/ template < layer_mode mode > affine_( const bn_& layer ); /*! ensures - Constructs affine_ so that it performs the same transformation as the supplied batch normalization layer. You would want to do this after you finish training a network with bn_ layers because the affine_ layer will execute faster. - #get_mode() == layer.get_mode() !*/ layer_mode get_mode( ) const; /*! ensures - returns the mode of this layer, either CONV_MODE or FC_MODE. !*/ void disable( ); /*! ensures - #get_layer_params().size() == 0. - when forward_inplace and backward_inplace are called, they return immediately doing nothing. This causes the layer to trivially perform an identity transform. !*/ alias_tensor_instance get_gamma(); /*! ensures - returns the gamma parameter that defines the behavior of forward_inplace(). !*/ alias_tensor_const_instance get_gamma() const; /*! ensures - returns the gamma parameter that defines the behavior of forward_inplace(). !*/ alias_tensor_instance get_beta(); /*! ensures - returns the beta parameter that defines the behavior of forward_inplace(). !*/ alias_tensor_const_instance get_beta() const; /*! ensures - returns the beta parameter that defines the behavior of forward_inplace().
!*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Also note that get_layer_params() always returns an empty tensor since there are no learnable parameters in this object. !*/ }; template using affine = add_layer; // ---------------------------------------------------------------------------------------- template < long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class max_pool_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - _nr >= 0 - _nc >= 0 - _stride_y > 0 - _stride_x > 0 - _padding_y >= 0 - _padding_x >= 0 - if (_nr != 0) then - _padding_y < _nr - else - _padding_y == 0 - if (_nc != 0) then - _padding_x < _nc - else - _padding_x == 0 WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a max pooling layer that takes an input tensor and downsamples it. It does this by sliding a window over the images in an input tensor and outputting, for each channel, the maximum element within the window. If _nr == 0 then it means the filter size covers all the rows in the input tensor, similarly for the _nc parameter. To be precise, if we call the input tensor IN and the output tensor OUT, then OUT is defined as follows: - let FILT_NR == (nr()==0) ? IN.nr() : nr() - let FILT_NC == (nc()==0) ?
IN.nc() : nc() - OUT.num_samples() == IN.num_samples() - OUT.k() == IN.k() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x() - for all valid s, k, r, and c: - image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k), centered_rect(c*stride_x() + FILT_NC/2 - padding_x(), r*stride_y() + FILT_NR/2 - padding_y(), FILT_NC, FILT_NR))) !*/ public: max_pool_ ( ); /*! ensures - #nr() == _nr - #nc() == _nc - #stride_y() == _stride_y - #stride_x() == _stride_x - #padding_y() == _padding_y - #padding_x() == _padding_x !*/ long nr( ) const; /*! ensures - returns the number of rows in the pooling window or 0 if the window size is "the entire input tensor". !*/ long nc( ) const; /*! ensures - returns the number of columns in the pooling window or 0 if the window size is "the entire input tensor". !*/ long stride_y( ) const; /*! ensures - returns the vertical stride used when scanning the max pooling window over an image. That is, each window will be moved stride_y() pixels down at a time when it moves over the image. !*/ long stride_x( ) const; /*! ensures - returns the horizontal stride used when scanning the max pooling window over an image. That is, each window will be moved stride_x() pixels to the right at a time when it moves over the image. !*/ long padding_y( ) const; /*! ensures - returns the number of pixels of zero padding added to the top and bottom sides of the image. !*/ long padding_x( ) const; /*! ensures - returns the number of pixels of zero padding added to the left and right sides of the image. !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template < long nr, long nc, int stride_y, int stride_x, typename SUBNET > using max_pool = add_layer, SUBNET>; template < typename SUBNET > using max_pool_everything = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _nr, long _nc, int _stride_y, int _stride_x, int _padding_y = _stride_y!=1? 0 : _nr/2, int _padding_x = _stride_x!=1? 0 : _nc/2 > class avg_pool_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - _nr >= 0 - _nc >= 0 - _stride_y > 0 - _stride_x > 0 - _padding_y >= 0 - _padding_x >= 0 - if (_nr != 0) then - _padding_y < _nr - else - _padding_y == 0 - if (_nc != 0) then - _padding_x < _nc - else - _padding_x == 0 WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines an average pooling layer that takes an input tensor and downsamples it. It does this by sliding a window over the images in an input tensor and outputting, for each channel, the average element within the window. If _nr == 0 then it means the filter size covers all the rows in the input tensor, similarly for the _nc parameter. To be precise, if we call the input tensor IN and the output tensor OUT, then OUT is defined as follows: - let FILT_NR == (nr()==0) ? IN.nr() : nr() - let FILT_NC == (nc()==0) ?
IN.nc() : nc() - OUT.num_samples() == IN.num_samples() - OUT.k() == IN.k() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x() - for all valid s, k, r, and c: - image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k), centered_rect(c*stride_x() + FILT_NC/2 - padding_x(), r*stride_y() + FILT_NR/2 - padding_y(), FILT_NC, FILT_NR))) !*/ public: avg_pool_ ( ); /*! ensures - #nr() == _nr - #nc() == _nc - #stride_y() == _stride_y - #stride_x() == _stride_x - #padding_y() == _padding_y - #padding_x() == _padding_x !*/ long nr( ) const; /*! ensures - returns the number of rows in the pooling window or 0 if the window size is "the entire input tensor". !*/ long nc( ) const; /*! ensures - returns the number of columns in the pooling window or 0 if the window size is "the entire input tensor". !*/ long stride_y( ) const; /*! ensures - returns the vertical stride used when scanning the pooling window over an image. That is, each window will be moved stride_y() pixels down at a time when it moves over the image. !*/ long stride_x( ) const; /*! ensures - returns the horizontal stride used when scanning the pooling window over an image. That is, each window will be moved stride_x() pixels to the right at a time when it moves over the image. !*/ long padding_y( ) const; /*! ensures - returns the number of pixels of zero padding added to the top and bottom sides of the image. !*/ long padding_x( ) const; /*! ensures - returns the number of pixels of zero padding added to the left and right sides of the image. !*/ template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template < long nr, long nc, int stride_y, int stride_x, typename SUBNET > using avg_pool = add_layer, SUBNET>; template < typename SUBNET > using avg_pool_everything = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class relu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a rectified linear layer. Therefore, it passes its inputs through the function f(x)=max(x,0) where f() is applied pointwise across the input tensor. !*/ public: relu_( ); void disable( ); /*! ensures - #get_layer_params().size() == 0. - when forward_inplace and backward_inplace are called, they return immediately doing nothing. This causes the layer to trivially perform an identity transform. !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using relu = add_layer; // ---------------------------------------------------------------------------------------- class prelu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a parametric rectified linear layer.
Therefore, it passes its inputs through the function f(x) = x>0 ? x : p*x where f() is applied pointwise across the input tensor and p is a scalar parameter learned by this layer. This is the layer type introduced in the paper: He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE International Conference on Computer Vision. 2015. !*/ public: explicit prelu_( float initial_param_value = 0.25 ); /*! ensures - The p parameter will be initialized with initial_param_value. - #get_initial_param_value() == initial_param_value. !*/ float get_initial_param_value ( ) const; /*! ensures - returns the initial value of the prelu parameter. !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template using prelu = add_layer; // ---------------------------------------------------------------------------------------- class leaky_relu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a leaky rectified linear layer. Therefore, it passes its inputs through the function f(x) = x>0 ? x : alpha*x where f() is applied pointwise across the input tensor and alpha is a non-learned scalar. This is the layer type introduced in the paper: A. L. Maas, A. Y. Hannun, and A. Y. Ng. "Rectifier nonlinearities improve neural network acoustic models". In ICML, 2013. !*/ public: explicit leaky_relu_( float alpha = 0.01f ); /*! 
ensures - the alpha parameter will be initialized with the alpha value !*/ float get_alpha( ) const; /*! ensures - returns the alpha parameter of the leaky_relu !*/ template void setup(const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using leaky_relu = add_layer; // ---------------------------------------------------------------------------------------- class sig_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a sigmoid layer. Therefore, it passes its inputs through the function f(x)=1/(1+exp(-x)) where f() is applied pointwise across the input tensor. !*/ public: sig_( ); template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using sig = add_layer; // ---------------------------------------------------------------------------------------- class mish_ { /*! 
WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a mish layer. Therefore, it passes its inputs through the function f(x)= x*tanh(log(1+exp(x))) where f() is applied pointwise across the input tensor. This is the layer type introduced in the paper: Diganta Misra. "Mish: A Self Regularized Non-Monotonic Activation Function" !*/ public: mish_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& data_output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor&); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using mish = add_layer; // ---------------------------------------------------------------------------------------- class htan_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a hyperbolic tangent layer. Therefore, it passes its inputs through the function f(x)=std::tanh(x) where f() is applied pointwise across the input tensor. !*/ public: htan_( ); template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. 
Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using htan = add_layer; // ---------------------------------------------------------------------------------------- class clipped_relu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a clipped version of the relu layer. Therefore, it passes its inputs through the function f(x) = min(max(x, 0), ceiling) where f() is applied pointwise across the input tensor and ceiling is a non-learned scalar. !*/ public: clipped_relu_( const float ceiling = 6.0f ); /*! ensures - the ceiling parameter will be initialized with the ceiling value !*/ float get_ceiling() const; /*! ensures - returns the ceiling parameter of the clipped_relu !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using clipped_relu = add_layer; // ---------------------------------------------------------------------------------------- class elu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines an exponential linear unit. Therefore, it passes its inputs through the function f(x) = x>0 ? x : alpha*(exp(x)-1) where f() is applied pointwise across the input tensor and alpha is a non-learned scalar.
This is the layer type introduced in the paper: Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter. "Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)". !*/ public: elu_( const float alpha = 1.0f ); /*! ensures - the alpha parameter will be initialized with the alpha value !*/ float get_alpha() const; /*! ensures - returns the alpha parameter of the elu !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using elu = add_layer; // ---------------------------------------------------------------------------------------- class gelu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a gelu layer. Therefore, it passes its inputs through the function f(x)= x/2 * (1 + erf(x/sqrt(2))) where f() is applied pointwise across the input tensor. This is the layer type introduced in the paper: Dan Hendrycks, Kevin Gimpel. "Gaussian Error Linear Units (GELUs)". !*/ public: gelu_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& data_output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor&); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using gelu = add_layer; // ---------------------------------------------------------------------------------------- class smelu_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a smooth rectified linear layer. Therefore, it passes its inputs through the function f(x): - if (x > beta) x - if (x < -beta) 0 - else std::pow(x + beta, 2) / (4 * beta) where f() is applied pointwise across the input tensor and beta is a non-learned scalar. (The pieces join continuously: at x == beta the quadratic equals beta, and at x == -beta it equals 0.) This is the layer type introduced in the paper: "Smooth activations and reproducibility in deep networks" by Gil I. Shamir, Dong Lin, Lorenzo Coviello (https://arxiv.org/abs/2010.09931) !*/ public: explicit smelu_( float beta = 1 ); /*! ensures - the beta parameter will be initialized with the beta value !*/ float get_beta( ) const; /*! ensures - returns the beta parameter of the smelu !*/ template void setup(const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using smelu = add_layer; // ---------------------------------------------------------------------------------------- class silu_ { /*!
WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a silu layer. Therefore, it passes its inputs through the function f(x)= x * sigmoid(x) = x / (1 + exp(-x)) where f() is applied pointwise across the input tensor. This is the layer type introduced in the paper: Dan Hendrycks, Kevin Gimpel. "Gaussian Error Linear Units (GELUs)". !*/ public: silu_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& data_output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor&); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using silu = add_layer; // ---------------------------------------------------------------------------------------- template class softmax_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. It defines a softmax layer with two modes of operation: channel-wise and plane-wise. The softmax function s(x) is defined as: s(x) == exp(x)/sum(exp(x)) where x is a vector. 1. Channel-wise mode (s_mode_ == CHANNEL_WISE): This mode treats the input tensor as a collection of multi-channel images and applies s() to each spatial location in each image. The tensor::k() channel elements at each position are input to s() and then replaced by the outputs of s(). 2. Plane-wise mode (s_mode_ == PLANE_WISE): This mode applies the softmax function across entire planes (nr x nc) of the input tensor, useful for operations in Large Language Models (LLMs) and other applications requiring 2D tensor processing. 
In both modes, the sum of the outputs of s() will always be equal to 1 for each application of the function. TEMPLATE PARAMETERS - s_mode_: Determines the mode of operation (CHANNEL_WISE or PLANE_WISE) !*/ public: softmax_(); template void setup(const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace( const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad ); const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ friend void serialize(const softmax_& item, std::ostream& out); friend void deserialize(softmax_& item, std::istream& in); friend std::ostream& operator<<(std::ostream& out, const softmax_& item); friend void to_xml(const softmax_& item, std::ostream& out); /*! These provide serialization support, printing of a string describing the layer, and XML output of the layer (as used by net_to_xml()). !*/ }; template using softmax = add_layer, SUBNET>; template using softmaxm = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class softmax_all_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, it defines a softmax layer. To be precise, we define the softmax function s(x) as: s(x) == exp(x)/sum(exp(x)) where x is a vector. Then this layer treats its input tensor as a collection of tensor::num_samples() vectors, each made up of all the k()*nr()*nc() values of one sample, and applies s() to each vector in the tensor. Therefore, there are logically tensor::num_samples() invocations of s(). !*/ public: softmax_all_( ); template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); const tensor& get_layer_params() const; tensor& get_layer_params(); /*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. Note that this layer doesn't have any parameters, so the tensor returned by get_layer_params() is always empty. !*/ }; template using softmax_all = add_layer; // ---------------------------------------------------------------------------------------- template < template class tag > class add_prev_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer simply adds the outputs of two previous layers. In particular, it adds the tensor from its immediate predecessor layer, sub.get_output(), with the tensor from a deeper layer, layer(sub).get_output(). You supply a tag via add_prev_'s template argument that tells it which layer to add to the output of the previous layer. The result of this addition is output by add_prev_. Finally, the addition happens pointwise according to 4D tensor arithmetic. If the dimensions don't match then missing elements are presumed to be equal to 0. Moreover, each dimension of the output tensor is equal to the maximum dimension of either of the inputs. That is, if the tensors A and B are being added to produce C then: - C.num_samples() == max(A.num_samples(), B.num_samples()) - C.k() == max(A.k(), B.k()) - C.nr() == max(A.nr(), B.nr()) - C.nc() == max(A.nc(), B.nc()) !*/ public: add_prev_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
!*/ }; template < template class tag, typename SUBNET > using add_prev = add_layer, SUBNET>; // Here we add some convenient aliases for using add_prev_ with the tag layers. template using add_prev1 = add_prev; template using add_prev2 = add_prev; template using add_prev3 = add_prev; template using add_prev4 = add_prev; template using add_prev5 = add_prev; template using add_prev6 = add_prev; template using add_prev7 = add_prev; template using add_prev8 = add_prev; template using add_prev9 = add_prev; template using add_prev10 = add_prev; using add_prev1_ = add_prev_; using add_prev2_ = add_prev_; using add_prev3_ = add_prev_; using add_prev4_ = add_prev_; using add_prev5_ = add_prev_; using add_prev6_ = add_prev_; using add_prev7_ = add_prev_; using add_prev8_ = add_prev_; using add_prev9_ = add_prev_; using add_prev10_ = add_prev_; // ---------------------------------------------------------------------------------------- template < template class tag > class mult_prev_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer simply multiplies the outputs of two previous layers. In particular, it multiplies the tensor from its immediate predecessor layer, sub.get_output(), with the tensor from a deeper layer, layer(sub).get_output(). You supply a tag via mult_prev_'s template argument that tells it which layer to multiply with the output of the previous layer. The result of this multiplication is output by mult_prev_. Finally, the multiplication happens pointwise (element-wise) according to 4D tensor arithmetic; it is not a matrix product (multm_prev_ provides that). If the dimensions don't match then missing elements are presumed to be equal to 0. Moreover, each dimension of the output tensor is equal to the maximum dimension of either of the inputs.
That is, if the tensors A and B are being multiplied to produce C then: - C.num_samples() == max(A.num_samples(), B.num_samples()) - C.k() == max(A.k(), B.k()) - C.nr() == max(A.nr(), B.nr()) - C.nc() == max(A.nc(), B.nc()) !*/ public: mult_prev_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < template class tag, typename SUBNET > using mult_prev = add_layer, SUBNET>; // Here we add some convenient aliases for using mult_prev_ with the tag layers. template using mult_prev1 = mult_prev; template using mult_prev2 = mult_prev; template using mult_prev3 = mult_prev; template using mult_prev4 = mult_prev; template using mult_prev5 = mult_prev; template using mult_prev6 = mult_prev; template using mult_prev7 = mult_prev; template using mult_prev8 = mult_prev; template using mult_prev9 = mult_prev; template using mult_prev10 = mult_prev; using mult_prev1_ = mult_prev_; using mult_prev2_ = mult_prev_; using mult_prev3_ = mult_prev_; using mult_prev4_ = mult_prev_; using mult_prev5_ = mult_prev_; using mult_prev6_ = mult_prev_; using mult_prev7_ = mult_prev_; using mult_prev8_ = mult_prev_; using mult_prev9_ = mult_prev_; using mult_prev10_ = mult_prev_; // ---------------------------------------------------------------------------------------- template < template class tag > class multm_prev_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer performs matrix multiplication on the outputs of two previous layers.
It multiplies the tensor from its immediate predecessor layer, sub.get_output(), with the tensor from a deeper layer, layer(sub).get_output(). The tag template argument specifies which layer to multiply with the output of the previous layer. The result of this multiplication is output by multm_prev_. The multiplication is performed using a modified version of gemm() to account for the 2D matrix dimension in the nr()xnc() planes of dlib's 4D tensors. This layer is similar to mult_prev_, but it considers the full matrix in the nr()xnc() planes of the tensor, rather than just the upper num_samples()xk() plane. This makes it suitable for implementing mechanisms like attention, especially when the k() channel plane is used to model multiple heads for parallel matrix processing. The output tensor dimensions are determined as follows: - output.num_samples() == t1.num_samples() - output.k() == t1.k() - output.nr() == t1.nr() - output.nc() == t2.nc() where t1 is sub.get_output() and t2 is layer(sub).get_output(). Since each output plane is the matrix product of an nr()xnc() plane of t1 with the corresponding plane of t2, the inner dimensions must agree: t1.nc() == t2.nr(). NOTE(review): presumably t2.num_samples() and t2.k() must also match those of t1; confirm against the implementation. !*/ public: multm_prev_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < template class tag, typename SUBNET > using multm_prev = add_layer, SUBNET>; // Here we add some convenient aliases for using multm_prev_ with the tag layers.
template using multm_prev1 = multm_prev; template using multm_prev2 = multm_prev; template using multm_prev3 = multm_prev; template using multm_prev4 = multm_prev; template using multm_prev5 = multm_prev; template using multm_prev6 = multm_prev; template using multm_prev7 = multm_prev; template using multm_prev8 = multm_prev; template using multm_prev9 = multm_prev; template using multm_prev10 = multm_prev; using multm_prev1_ = multm_prev_; using multm_prev2_ = multm_prev_; using multm_prev3_ = multm_prev_; using multm_prev4_ = multm_prev_; using multm_prev5_ = multm_prev_; using multm_prev6_ = multm_prev_; using multm_prev7_ = multm_prev_; using multm_prev8_ = multm_prev_; using multm_prev9_ = multm_prev_; using multm_prev10_ = multm_prev_; // ---------------------------------------------------------------------------------------- template < template class tag > class resize_prev_to_tagged_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer resizes the output channels of the previous layer to have the same number of rows and columns as the output of the tagged layer. This layer uses bilinear interpolation. If the sizes match already, then it simply copies the data. You supply a tag via resize_prev_to_tagged_'s template argument that tells it which layer's output defines the target size. If tensor PREV is resized to the size of tensor TAGGED, then a tensor OUT is produced such that: - OUT.num_samples() == PREV.num_samples() - OUT.k() == PREV.k() - OUT.nr() == TAGGED.nr() - OUT.nc() == TAGGED.nc() !*/ public: resize_prev_to_tagged_( ); template void setup(const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < template class tag, typename SUBNET > using resize_prev_to_tagged = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < template class tag > class scale_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer scales the output channels of the tagged layer by multiplying them with the output of the previous layer. To be specific: - Let INPUT == layer(sub).get_output() - Let SCALES == sub.get_output() - This layer takes INPUT and SCALES as input. - The output of this layer has the same dimensions as INPUT. - This layer requires: - SCALES.num_samples() == INPUT.num_samples() - SCALES.k() == INPUT.k() - SCALES.nr() == 1 - SCALES.nc() == 1 - The output tensor is produced by pointwise multiplying SCALES with INPUT at each spatial location. Therefore, if OUT is the output of this layer then we would have: OUT(n,k,r,c) == INPUT(n,k,r,c)*SCALES(n,k) Note that, unlike scale_prev_, this layer does not provide map_input_to_output() or map_output_to_input(). !*/ public: scale_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < template class tag, typename SUBNET > using scale = add_layer, SUBNET>; // Here we add some convenient aliases for using scale_ with the tag layers.
template using scale1 = scale; template using scale2 = scale; template using scale3 = scale; template using scale4 = scale; template using scale5 = scale; template using scale6 = scale; template using scale7 = scale; template using scale8 = scale; template using scale9 = scale; template using scale10 = scale; using scale1_ = scale_; using scale2_ = scale_; using scale3_ = scale_; using scale4_ = scale_; using scale5_ = scale_; using scale6_ = scale_; using scale7_ = scale_; using scale8_ = scale_; using scale9_ = scale_; using scale10_ = scale_; // ---------------------------------------------------------------------------------------- template < template class tag > class scale_prev_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer scales the output channels of the previous layer by multiplying them with the output of the tagged layer. It is exactly the same as the scale_ layer, but with the inputs swapped. This is useful because the scaled tensor then comes from the immediately preceding layer, which allows this layer to map points between its input and output (see map_input_to_output() and map_output_to_input(), which scale_ does not provide). To be specific: - Let INPUT == sub.get_output() - Let SCALES == layer(sub).get_output() - This layer takes INPUT and SCALES as input. - The output of this layer has the same dimensions as INPUT. - This layer requires: - SCALES.num_samples() == INPUT.num_samples() - SCALES.k() == INPUT.k() - SCALES.nr() == 1 - SCALES.nc() == 1 - The output tensor is produced by pointwise multiplying SCALES with INPUT at each spatial location.
Therefore, if OUT is the output of this layer then we would have: OUT(n,k,r,c) == INPUT(n,k,r,c)*SCALES(n,k) !*/ public: scale_prev_( ); template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < template class tag, typename SUBNET > using scale_prev = add_layer, SUBNET>; // Here we add some convenient aliases for using scale_prev_ with the tag layers. template using scale_prev1 = scale_prev; template using scale_prev2 = scale_prev; template using scale_prev3 = scale_prev; template using scale_prev4 = scale_prev; template using scale_prev5 = scale_prev; template using scale_prev6 = scale_prev; template using scale_prev7 = scale_prev; template using scale_prev8 = scale_prev; template using scale_prev9 = scale_prev; template using scale_prev10 = scale_prev; using scale_prev1_ = scale_prev_; using scale_prev2_ = scale_prev_; using scale_prev3_ = scale_prev_; using scale_prev4_ = scale_prev_; using scale_prev5_ = scale_prev_; using scale_prev6_ = scale_prev_; using scale_prev7_ = scale_prev_; using scale_prev8_ = scale_prev_; using scale_prev9_ = scale_prev_; using scale_prev10_ = scale_prev_; // ---------------------------------------------------------------------------------------- template< template class... TAG_TYPES > class concat_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. This layer simply concatenates the outputs of the tagged layers. Importantly, each input layer must have the same dimensions (i.e. num_samples, nr, and nc) except for the k channel, which may vary.
This is because the concatenation happens along the k dimension. That is, the output of this layer is a tensor, OUT, that is the concatenation of the tensors: for each (tag in TAG_TYPES) layer(subnet).get_output() Therefore, OUT.num_samples(), OUT.nr(), and OUT.nc() match the dimensions of the input tensors while OUT.k() is the sum of the input layers' k() dimensions. !*/ public: template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; // concat layer definitions template class TAG1, template class TAG2, typename SUBNET> using concat2 = add_layer, SUBNET>; template class TAG1, template class TAG2, template class TAG3, typename SUBNET> using concat3 = add_layer, SUBNET>; template class TAG1, template class TAG2, template class TAG3, template class TAG4, typename SUBNET> using concat4 = add_layer, SUBNET>; template class TAG1, template class TAG2, template class TAG3, template class TAG4, template class TAG5, typename SUBNET> using concat5 = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- /*! Inception layer definitions !*/ // Now define inception layer tag types. These layer aliases allow creating // the networks described in the paper: // Szegedy, Christian, et al. "Going deeper with convolutions." Proceedings of // the IEEE Conference on Computer Vision and Pattern Recognition. 2015. // See the dnn_inception_ex.cpp example for a complete example of their use. Note also // that we use tag ID numbers >= 1000 to avoid conflict with users' tag layers.
template using itag0 = add_tag_layer< 1000 + 0, SUBNET>; template using itag1 = add_tag_layer< 1000 + 1, SUBNET>; template using itag2 = add_tag_layer< 1000 + 2, SUBNET>; template using itag3 = add_tag_layer< 1000 + 3, SUBNET>; template using itag4 = add_tag_layer< 1000 + 4, SUBNET>; template using itag5 = add_tag_layer< 1000 + 5, SUBNET>; // skip to inception input template using iskip = add_skip_layer< itag0, SUBNET>; // here are some templates to be used for creating inception layer groups template class B1, templateclass B2, typename SUBNET> using inception2 = concat2>>>>>>; template class B1, templateclass B2, templateclass B3, typename SUBNET> using inception3 = concat3>>>>>>>>>; template class B1, templateclass B2, templateclass B3, templateclass B4, typename SUBNET> using inception4 = concat4>>>>>>>>>>>>; template class B1, templateclass B2, templateclass B3, templateclass B4, templateclass B5, typename SUBNET> using inception5 = concat5>>>>>>>>>>>>>>>; // ---------------------------------------------------------------------------------------- const double DEFAULT_L2_NORM_EPS = 1e-5; class l2normalize_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. It takes tensors as input and L2 normalizes them. In particular, it has the following properties: - The output tensors from this layer have the same dimensions as the input tensors. - If you think of each input tensor as a set of tensor::num_samples() vectors, then the output tensor contains the same vectors except they have been length normalized so that their L2 norms are all 1. I.e. for each vector v we will have ||v||==1 (up to the small get_eps() term that is added to the squared norm to avoid division by zero). !*/ public: explicit l2normalize_( double eps = tt::DEFAULT_L2_NORM_EPS ); /*! requires - eps > 0 ensures - #get_eps() == eps NOTE(review): the default argument names tt::DEFAULT_L2_NORM_EPS, while the constant declared just above this class is DEFAULT_L2_NORM_EPS in the current namespace; confirm which one is intended. !*/ double get_eps( ) const; /*! ensures - When we normalize a vector we divide it by its L2 norm.
However, the get_eps() value is added to the squared norm prior to division to avoid ever dividing by zero. !*/ template void setup (const SUBNET& sub); void forward_inplace(const tensor& input, tensor& output); void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; // ---------------------------------------------------------------------------------------- template < long _offset, long _k, long _nr, long _nc > class extract_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - 0 <= _offset - 0 < _k - 0 < _nr - 0 < _nc WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, the output of this layer is simply a copy of the input tensor. However, you can configure the extract layer to output only some subset of the input tensor and also to reshape it. Therefore, the dimensions of the tensor output by this layer are as follows (letting IN be the input tensor and OUT the output tensor): - OUT.num_samples() == IN.num_samples() - OUT.k() == _k - OUT.nr() == _nr - OUT.nc() == _nc So the output will always have the same number of samples as the input, but within each sample (the k,nr,nc part) we will copy only a subset of the values. Moreover, the _offset parameter controls which part of each sample we take. To be very precise, we will have: - let IN_SIZE = IN.k()*IN.nr()*IN.nc() - let OUT_SIZE = _k*_nr*_nc - for i in range[0,IN.num_samples()) and j in range[0,OUT_SIZE): - OUT.host()[i*OUT_SIZE+j] == IN.host()[i*IN_SIZE+_offset+j] Finally, all this means that the input tensor to this layer must have a big enough size to accommodate taking a _k*_nr*_nc slice from each of its samples, i.e. _offset + _k*_nr*_nc <= IN.k()*IN.nr()*IN.nc().
!*/ public: template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < long offset, long k, long nr, long nc, typename SUBNET > using extract = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template < long _offset_k, long _offset_nr, long _offset_nc, long _k, long _nr, long _nc > class slice_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - 0 <= _offset_k - 0 <= _offset_nr - 0 <= _offset_nc - 0 < _k - 0 < _nr - 0 < _nc WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, the output of this layer is simply a copy of the input tensor. It is similar to extract in that you can configure the slice layer to output only some subset of the input tensor, but slice allows copies of non-contiguous regions of the input which enables three dimensional cropping of a tensor. The dimensions of the tensor output by this layer are as follows (letting IN be the input tensor and OUT the output tensor): - OUT.num_samples() == IN.num_samples() - OUT.k() == _k - OUT.nr() == _nr - OUT.nc() == _nc So the output will always have the same number of samples as the input, but within each sample (the k,nr,nc part) we will copy only a subset of the values. Moreover, the _offset_k, _offset_nr, and _offset_nc parameters control which channels, rows, and columns of each sample we take.
To be very precise, we will have: - let IN_SIZE = IN.k()*IN.nr()*IN.nc() - let OUT_SIZE = _k*_nr*_nc - for i in range[0,IN.num_samples()) and j in range[0,OUT_SIZE): - let k = (j / (OUT.nr()*OUT.nc())) % OUT.k() - let r = (j / OUT.nc()) % OUT.nr() - let c = j % OUT.nc() - OUT.host()[i*OUT_SIZE+j] == IN.host()[i*IN_SIZE+ k_stride*(_offset_k+k)+ row_stride*(_offset_nr+r)+ col_stride*(_offset_nc+c)] where k_stride == IN.nr()*IN.nc(), row_stride == IN.nc(), and col_stride == 1. Finally, all this means that the input tensor to this layer must be big enough to accommodate taking a _k*_nr*_nc slice from each of its samples, i.e. _offset_k+_k <= IN.k(), _offset_nr+_nr <= IN.nr(), and _offset_nc+_nc <= IN.nc(). !*/ public: template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template < long offset_k, long offset_nr, long offset_nc, long k, long nr, long nc, typename SUBNET > using slice = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template class reorg_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - row_stride >= 1 - col_stride >= 1 WHAT THIS OBJECT REPRESENTS This class implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface, performing a reorganization of tensor data. It rearranges spatial information along the channel dimension, effectively "folding" spatial dimensions into channels. The dimensions of the output tensor are as follows (letting IN be the input tensor and OUT the output tensor): - OUT.num_samples() == IN.num_samples() - OUT.k() == IN.k() * row_stride * col_stride - OUT.nr() == IN.nr() / row_stride - OUT.nc() == IN.nc() / col_stride Therefore, the output tensor maintains the same number of samples as the input but alters the channel and spatial dimensions based on the specified strides.
Specifically, for all n, k, r, c in OUT: OUT.host[tensor_index(OUT, n, k, r, c)] == IN.host[tensor_index(IN, n, k % IN.k(), r * row_stride + (k / IN.k()) / col_stride, c * col_stride + (k / IN.k()) % col_stride)] Implementation note: the tensor-level utility functions reorg() and reorg_gradient() that this layer is built on accept an optional bool add_to argument. This layer calls them with the default value, so its behavior is exactly as described above; the argument exists so that other code can accumulate results (e.g. gradients) into the destination tensor instead of overwriting it. You can think of this layer as a parameter-free alternative to a strided convolutional layer for downsampling tensors: it achieves a similar spatial reduction, but purely by rearranging elements (as the formula above shows) rather than by learned filtering, so each gradient value is routed back to the single input element it came from. !*/ public: template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output (dpoint p) const; dpoint map_output_to_input (dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ }; template using reorg = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class transpose_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. In particular, this layer performs a 2D matrix transposition on each of the k planes within each sample of a 4D tensor.
The dimensions of the tensor output by this layer are as follows (letting IN be the input tensor and OUT the output tensor): - OUT.num_samples() == IN.num_samples() - OUT.k() == IN.k() - OUT.nr() == IN.nc() - OUT.nc() == IN.nr() The transposition is performed as follows: - For each sample i and each k-plane j: - OUT[i][j][r][c] = IN[i][j][c][r] for all r in [0, IN.nc()) and c in [0, IN.nr()) This layer does not have any learnable parameters (the params member declared below is unused). !*/ public: transpose_() = default; template void setup (const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); inline dpoint map_input_to_output(dpoint p) const; inline dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); friend void serialize(const transpose_& item, std::ostream& out); friend void deserialize(transpose_& item, std::istream& in); friend std::ostream& operator<<(std::ostream& out, const transpose_& item); friend void to_xml(const transpose_& item, std::ostream& out); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. !*/ private: resizable_tensor params; // unused }; template using transpose = add_layer; // ---------------------------------------------------------------------------------------- class positional_encodings_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. It defines a positional encoding layer that adds position information to the input tensor. This is particularly useful in transformer architectures where the order of the sequence matters. The dimensions of the tensors output by this layer are the same as the input tensor dimensions. This implementation is based on the positional encoding described in: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017).
Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008). The encoding uses sine and cosine functions of different frequencies: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) where pos is the position in the sequence, d_model is the embedding dimension (presumably embedding_dim_), and i indexes pairs of embedding dimensions, so even dimensions use sin() and odd dimensions use cos(). !*/ public: positional_encodings_( unsigned long sequence_dim_ = 1, unsigned long embedding_dim_ = 1 ); /*! ensures - #sequence_dim == sequence_dim_ - #embedding_dim == embedding_dim_ !*/ positional_encodings_ ( const positional_encodings_& item ); /*! ensures - positional_encodings_ objects are copy constructable !*/ positional_encodings_& operator=( const positional_encodings_& item ); /*! ensures - positional_encodings_ objects are assignable !*/ template void setup ( const SUBNET& sub ); /*! requires - SUBNET implements the SUBNET interface defined at the top of this file. ensures - performs any necessary setup for the layer, including the calculation of positional encodings based on the dimensions of the input. !*/ template void forward( const SUBNET& sub, resizable_tensor& output ); /*! requires - SUBNET implements the SUBNET interface defined at the top of this file. - setup() has been called. ensures - Adds the positional encodings to the output of the subnetwork and stores the results into #output. !*/ template void backward( const tensor& gradient_input, SUBNET& sub, tensor& params_grad ); /*! requires - SUBNET implements the SUBNET interface defined at the top of this file. - setup() has been called. - params_grad is unused in this layer as there are no learnable parameters. ensures - Computes the gradient of the layer with respect to the input, which is simply gradient_input itself since the positional encodings are constant. !*/ const tensor& get_layer_params( ) const; /*! ensures - returns the parameters that define the behavior of forward(). Note: This layer has no learnable parameters, so this returns an empty tensor.
!*/ tensor& get_layer_params( ); /*! ensures - returns the parameters that define the behavior of forward(). Note: This layer has no learnable parameters, so this returns an empty tensor. !*/ const tensor& get_positional_encodings( ) const; /*! ensures - returns the positional encodings computed by setup(). !*/ tensor& get_positional_encodings( ); /*! ensures - returns the positional encodings computed by setup(). !*/ friend void serialize(const positional_encodings_& item, std::ostream& out); friend void deserialize(positional_encodings_& item, std::istream& in); /*! provides serialization support !*/ friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& item); /*! print a string describing this layer. !*/ friend void to_xml(const positional_encodings_& item, std::ostream& out); /*! This function is optional, but required if you want to print your networks with net_to_xml(). It prints a layer as XML. !*/ }; template using positional_encodings = add_layer; // ---------------------------------------------------------------------------------------- template < unsigned long num_embeddings_, unsigned long embedding_dim_ > class embeddings_ { /*! WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. It defines an embedding layer, which maps discrete tokens to continuous vector representations. This is a fundamental technique in natural language processing and other domains dealing with categorical data. The layer takes as input a tensor of integer indices and outputs a tensor of the same shape (except for the last dimension) where each index is replaced by its corresponding embedding vector. For more information on embeddings, see: Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., & Dean, J. (2013). Distributed representations of words and phrases and their compositionality. In Advances in neural information processing systems (pp. 3111-3119).
TEMPLATE PARAMETERS - num_embeddings_: The size of the embedding dictionary, i.e., the number of discrete tokens that can be embedded. - embedding_dim_: The dimensionality of each embedding vector. CONVENTION - get_embeddings() returns the tensor of embedding vectors. - get_num_embeddings() == num_embeddings_ - get_embedding_dim() == embedding_dim_ - get_learning_rate_multiplier() returns the learning rate multiplier for this layer. - get_scale_by_freq() returns whether to scale gradients by token frequency. */ public: embeddings_() = default; unsigned long get_num_embeddings() const; unsigned long get_embedding_dim() const; double get_learning_rate_multiplier() const; bool get_scale_by_freq() const; void set_num_embeddings(unsigned long num); void set_embedding_dim(unsigned long dim); void set_learning_rate_multiplier(double val); void set_scale_by_freq(bool val); template void setup(const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); const tensor& get_layer_params() const; tensor& get_layer_params(); const tensor& get_embeddings() const; tensor& get_embeddings(); friend void serialize(const embeddings_& item, std::ostream& out); friend void deserialize(embeddings_& item, std::istream& in); friend std::ostream& operator<<(std::ostream& out, const embeddings_& item); friend void to_xml(const embeddings_& item, std::ostream& out); /*! These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. 
!*/ }; template < unsigned long num_embeddings, unsigned long embedding_dim, typename SUBNET > using embeddings = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- struct neg_infinity_tag {}; struct zero_tag {}; template struct is_special_value : std::false_type {}; template<> struct is_special_value : std::true_type {}; template<> struct is_special_value : std::true_type {}; template class tril_ { /*! TEMPLATE PARAMETERS - diag_: A long integer specifying the diagonal offset. - tag_: A type tag specifying special values or void for numeric values. - num_: Numerator for numeric diagonal value (default is 0, only used if tag_ is void). - den_: Denominator for numeric diagonal value (default is 1, only used if tag_ is void). REQUIREMENTS - diag_ must be an integer. - tag_ must be either neg_infinity_tag, zero_tag, or void. - If tag_ is void, num_ and den_ are used to compute the diagonal value. - If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored. WHAT THIS OBJECT REPRESENTS This object implements a layer in a deep neural network that applies a lower triangular mask to its input tensor. The mask is defined such that all elements above the specified diagonal are set to a given value. The diagonal offset and the mask value are determined by the template parameters. DIAGONAL VALUE DETERMINATION - If tag_ is neg_infinity_tag: diagonal value is set to negative infinity. - If tag_ is zero_tag: diagonal value is set to zero. - If tag_ is void: diagonal value is set to num_ / den_ as a float. 
DIAGONAL OFFSET The diag_ parameter determines the diagonal above which elements are masked: - diag_ = 0: main diagonal - diag_ > 0: diag_ steps above the main diagonal - diag_ < 0: |diag_| steps below the main diagonal EXAMPLE USAGE // Create a layer that masks all elements above the main diagonal with -inf tril_<0, neg_infinity_tag> layer1; // Create a layer that masks all elements above the main diagonal with 0 tril_<0, zero_tag> layer2; // Create a layer that masks all elements above the main diagonal with 0.5 tril_<0, void, 1, 2> layer3; // Create a layer that masks all elements 5 positions above the main diagonal with -inf tril_<5, neg_infinity_tag> layer4; // Create a layer that masks all elements 3 positions below the main diagonal with 0.25 tril_<-3, void, 1, 4> layer5; SERIALIZATION SUPPORT This object supports serialization and deserialization via the serialize() and deserialize() functions. !*/ public: tril_() = default; /*! ensures - This object is properly initialized. !*/ template void setup(const SUBNET& sub); /*! requires - SUBNET is a valid network layer type. ensures - Initializes the mask based on the dimensions of the input tensor from sub. !*/ template void forward(const SUBNET& sub, resizable_tensor& output); /*! requires - SUBNET is a valid network layer type. ensures - Applies the lower triangular mask to the input tensor from sub and stores the result in output. !*/ template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); /*! requires - SUBNET is a valid network layer type. ensures - Computes the gradient of the loss with respect to the input tensor and stores it in sub. !*/ inline dpoint map_input_to_output(const dpoint& p) const; /*! ensures - Maps a point from the input tensor to the corresponding point in the output tensor. !*/ inline dpoint map_output_to_input(const dpoint& p) const; /*! ensures - Maps a point from the output tensor to the corresponding point in the input tensor. 
!*/ const tensor& get_layer_params() const; /*! ensures - Returns the parameters of this layer. !*/ tensor& get_layer_params(); /*! ensures - Returns the parameters of this layer. !*/ friend void serialize(const tril_& item, std::ostream& out); /*! ensures - Serializes the state of this object to the given output stream. !*/ friend void deserialize(tril_& item, std::istream& in); /*! ensures - Deserializes the state of this object from the given input stream. !*/ friend std::ostream& operator<<(std::ostream& out, const tril_& item); /*! ensures - Prints a human-readable representation of this object to the given output stream. !*/ friend void to_xml(const tril_& item, std::ostream& out); /*! ensures - Serializes the state of this object to XML format and writes it to the given output stream. !*/ }; template using tril = add_layer, SUBNET>; template using tril_mask = add_layer, SUBNET>; template using tril_diag = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- template class adaptive_computation_time_ { /*! REQUIREMENTS ON TEMPLATE ARGUMENTS - max_steps > 0 WHAT THIS OBJECT REPRESENTS This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface defined above. It implements Adaptive Computation Time (ACT) following Graves (2016) "Adaptive Computation Time for Recurrent Neural Networks" (arXiv:1603.08983). ACT allows the network to adaptively determine how many computational steps to perform for each position in the input sequence, spending more computation on difficult parts while quickly processing easier parts. 
MATHEMATICAL FOUNDATION: Core ACT Algorithm: - For each sequence position t, perform up to max_steps computational steps - At step n, compute halting probability: p_t^n = sigmoid(W_halt^T * s_t^n + b_halt) - Cumulative halting probability: h_t^n = sum_{i=1 to n} p_t^i - Stop when h_t^n >= theta (threshold, typically 0.99) - Final output: y_t = sum_{n=1 to N(t)} alpha_t^n * y_hat_t^n Where alpha_t^n (effective weight) is computed as: - alpha_t^n = p_t^n * rho_t^{n-1} for intermediate steps - alpha_t^{N(t)} = 1 - h_t^{N(t)-1} (remainder for final step) - rho_t^n = 1 - h_t^n (remaining probability mass) PONDER COST (Regularization): - R(x) = (1/T) * sum_t (N(t) + rho_t^{N(t)}) - Total loss: L = L_task + lambda * R(x) - lambda is controlled by get_ponder_penalty() IMPLEMENTATION DETAILS: - Input/Output tensors have identical dimensions - Learnable parameters: W_halt in R^{d x k}, b_halt in R - State tracking per position: cumulative halting, remainders, step counts - Early termination when all positions halt !*/ public: adaptive_computation_time_(); /*! ensures - #get_max_steps() == max_steps - #get_halt_threshold() == 0.99f - #get_ponder_penalty() == 0.01f !*/ // Configuration accessors long get_max_steps() const; float get_halt_threshold() const; float get_ponder_penalty() const; void set_halt_threshold(float threshold); /*! requires - 0 < threshold <= 1.0f ensures - #get_halt_threshold() == threshold !*/ void set_ponder_penalty(float penalty); /*! requires - penalty >= 0 ensures - #get_ponder_penalty() == penalty !*/ // Runtime statistics float get_ponder_cost() const; /*! ensures - returns the ponder cost R(x) from the most recent forward pass - value represents average computational cost per position !*/ float get_average_steps() const; /*! 
ensures - returns the average number of computation steps per position - value is between 1.0 and max_steps !*/ // Layer interface template void setup(const SUBNET& sub); template void forward(const SUBNET& sub, resizable_tensor& output); template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); dpoint map_input_to_output(dpoint p) const; dpoint map_output_to_input(dpoint p) const; const tensor& get_layer_params() const; tensor& get_layer_params(); friend void serialize(const adaptive_computation_time_& item, std::ostream& out); friend void deserialize(adaptive_computation_time_& item, std::istream& in); friend std::ostream& operator<<(std::ostream& out, const adaptive_computation_time_& item); friend void to_xml(const adaptive_computation_time_& item, std::ostream& out); /*! provides serialization support and output operators !*/ }; template using adaptive_computation_time = add_layer, SUBNET>; template using act = add_layer, SUBNET>; template using act4 = add_layer, SUBNET>; template using act16 = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- } #endif // DLIB_DNn_LAYERS_ABSTRACT_H_ ================================================ FILE: dlib/dnn/loss.h ================================================ // Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. 
#ifndef DLIB_DNn_LOSS_H_ #define DLIB_DNn_LOSS_H_ #include "loss_abstract.h" #include "core.h" #include "utilities.h" #include "../matrix.h" #include "../cuda/tensor_tools.h" #include "../geometry.h" #include "../image_processing/box_overlap_testing.h" #include "../image_processing/full_object_detection.h" #include "../svm/ranking_tools.h" #include #include #include namespace dlib { // ---------------------------------------------------------------------------------------- class loss_binary_hinge_ { public: typedef float training_label_type; typedef float output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 && output_tensor.k() == 1); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { *iter++ = out_data[i]; } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 && output_tensor.k() == 1); // The loss we output is the average loss over the mini-batch. 
const double scale = 1.0/output_tensor.num_samples(); double loss = 0; const float* out_data = output_tensor.host(); float* g = grad.host_write_only(); for (long i = 0; i < output_tensor.num_samples(); ++i) { const float y = *truth++; DLIB_CASSERT(y == +1 || y == -1, "y: " << y); const float temp = 1-y*out_data[i]; if (temp > 0) { loss += scale*temp; g[i] = -scale*y; } else { g[i] = 0; } } return loss; } friend void serialize(const loss_binary_hinge_& , std::ostream& out) { serialize("loss_binary_hinge_", out); } friend void deserialize(loss_binary_hinge_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_binary_hinge_") throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_hinge_."); } friend std::ostream& operator<<(std::ostream& out, const loss_binary_hinge_& ) { out << "loss_binary_hinge"; return out; } friend void to_xml(const loss_binary_hinge_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_binary_hinge = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_binary_log_ { public: typedef float training_label_type; typedef float output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 && output_tensor.k() == 1); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { *iter++ = out_data[i]; } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = 
sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 && output_tensor.k() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1 && grad.k() == 1); tt::sigmoid(grad, output_tensor); // The loss we output is the average loss over the mini-batch. const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { const float y = *truth++; DLIB_CASSERT(y != 0, "y: " << y); float temp; if (y > 0) { temp = log1pexp(-out_data[i]); loss += y*scale*temp; g[i] = y*scale*(g[i]-1); } else { temp = -(-out_data[i]-log1pexp(-out_data[i])); loss += -y*scale*temp; g[i] = -y*scale*g[i]; } } return loss; } friend void serialize(const loss_binary_log_& , std::ostream& out) { serialize("loss_binary_log_", out); } friend void deserialize(loss_binary_log_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_binary_log_") throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_log_."); } friend std::ostream& operator<<(std::ostream& out, const loss_binary_log_& ) { out << "loss_binary_log"; return out; } friend void to_xml(const loss_binary_log_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_binary_log = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_multiclass_log_ { public: typedef unsigned long training_label_type; typedef unsigned long output_label_type; template < typename SUB_TYPE, typename label_iterator > void 
to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 ); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); // Note that output_tensor.k() should match the number of labels. for (long i = 0; i < output_tensor.num_samples(); ++i) { // The index of the largest output for this sample is the label. *iter++ = index_of_max(rowm(mat(output_tensor),i)); } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1); tt::softmax(grad, output_tensor); // The loss we output is the average loss over the mini-batch. const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { const long y = (long)*truth++; // The network must produce a number of outputs that is equal to the number // of labels when using this type of loss. 
DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k()); for (long k = 0; k < output_tensor.k(); ++k) { const unsigned long idx = i*output_tensor.k()+k; if (k == y) { loss += scale*-safe_log(g[idx]); g[idx] = scale*(g[idx]-1); } else { g[idx] = scale*g[idx]; } } } return loss; } friend void serialize(const loss_multiclass_log_& , std::ostream& out) { serialize("loss_multiclass_log_", out); } friend void deserialize(loss_multiclass_log_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_multiclass_log_") throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_."); } friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_& ) { out << "loss_multiclass_log"; return out; } friend void to_xml(const loss_multiclass_log_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_multiclass_log = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_multiclass_log_weighted_ { public: typedef dlib::weighted_label weighted_label; typedef weighted_label training_label_type; typedef unsigned long output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 ); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); // Note that output_tensor.k() should match the number of labels. for (long i = 0; i < output_tensor.num_samples(); ++i) { // The index of the largest output for this sample is the label. 
*iter++ = index_of_max(rowm(mat(output_tensor),i)); } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1); tt::softmax(grad, output_tensor); // The loss we output is the average loss over the mini-batch. const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { const auto wl = *truth++; const long y = wl.label; const float weight = wl.weight; // The network must produce a number of outputs that is equal to the number // of labels when using this type of loss. 
DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k()); for (long k = 0; k < output_tensor.k(); ++k) { const unsigned long idx = i*output_tensor.k()+k; if (k == y) { loss += weight*scale*-safe_log(g[idx]); g[idx] =weight*scale*(g[idx]-1); } else { g[idx] = weight*scale*g[idx]; } } } return loss; } friend void serialize(const loss_multiclass_log_weighted_& , std::ostream& out) { serialize("loss_multiclass_log_weighted_", out); } friend void deserialize(loss_multiclass_log_weighted_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_multiclass_log_weighted_") throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_weighted_."); } friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_weighted_& ) { out << "loss_multiclass_log_weighted"; return out; } friend void to_xml(const loss_multiclass_log_weighted_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_multiclass_log_weighted = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_multimulticlass_log_ { public: loss_multimulticlass_log_ () = default; loss_multimulticlass_log_ ( const std::map>& labels ) { for (auto& l : labels) { possible_labels[l.first] = std::make_shared(l.second); DLIB_CASSERT(l.second.size() >= 2, "Each classifier must have at least two possible labels."); for (size_t i = 0; i < l.second.size(); ++i) { label_idx_lookup[l.first][l.second[i]] = i; ++total_num_labels; } } } unsigned long number_of_labels() const { return total_num_labels; } unsigned long number_of_classifiers() const { return possible_labels.size(); } std::map> get_labels ( ) const { std::map> info; for (auto& i : possible_labels) { for (auto& label : *i.second) info[i.first].emplace_back(label); } return info; } class classifier_output { public: classifier_output() = default; size_t num_classes() const { 
return class_probs.size(); } double probability_of_class ( size_t i ) const { DLIB_CASSERT(i < num_classes()); return class_probs(i); } const std::string& label( size_t i ) const { DLIB_CASSERT(i < num_classes()); return (*_labels)[i]; } operator std::string( ) const { DLIB_CASSERT(num_classes() != 0); return (*_labels)[index_of_max(class_probs)]; } friend std::ostream& operator<< (std::ostream& out, const classifier_output& item) { DLIB_ASSERT(item.num_classes() != 0); out << static_cast(item); return out; } private: friend class loss_multimulticlass_log_; template classifier_output( const matrix_exp& class_probs, const std::shared_ptr>& _labels ) : class_probs(class_probs), _labels(_labels) { } matrix class_probs; std::shared_ptr> _labels; }; typedef std::map training_label_type; typedef std::map output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter_begin ) const { const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 ); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(number_of_labels() != 0, "You must give the loss_multimulticlass_log_'s constructor label data before you can use it!"); DLIB_CASSERT(output_tensor.k() == (long)number_of_labels(), "The output tensor must have " << number_of_labels() << " channels."); long k_offset = 0; for (auto& l : possible_labels) { auto iter = iter_begin; const std::string& classifier_name = l.first; const auto& labels = (*l.second); scratch.set_size(output_tensor.num_samples(), labels.size()); tt::copy_tensor(false, scratch, 0, output_tensor, k_offset, labels.size()); tt::softmax(scratch, scratch); for (long i = 0; i < scratch.num_samples(); ++i) (*iter++)[classifier_name] = classifier_output(rowm(mat(scratch),i), l.second); k_offset += labels.size(); } } template < typename 
const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth_begin, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1); DLIB_CASSERT(number_of_labels() != 0, "You must give the loss_multimulticlass_log_'s constructor label data before you can use it!"); DLIB_CASSERT(output_tensor.k() == (long)number_of_labels(), "The output tensor must have " << number_of_labels() << " channels."); // The loss we output is the average loss over the mini-batch. 
const double scale = 1.0/output_tensor.num_samples(); double loss = 0; long k_offset = 0; for (auto& l : label_idx_lookup) { const std::string& classifier_name = l.first; const auto& int_labels = l.second; scratch.set_size(output_tensor.num_samples(), int_labels.size()); tt::copy_tensor(false, scratch, 0, output_tensor, k_offset, int_labels.size()); tt::softmax(scratch, scratch); auto truth = truth_begin; float* g = scratch.host(); for (long i = 0; i < scratch.num_samples(); ++i) { const long y = int_labels.at(truth->at(classifier_name)); ++truth; for (long k = 0; k < scratch.k(); ++k) { const unsigned long idx = i*scratch.k()+k; if (k == y) { loss += scale*-std::log(g[idx]); g[idx] = scale*(g[idx]-1); } else { g[idx] = scale*g[idx]; } } } tt::copy_tensor(false, grad, k_offset, scratch, 0, int_labels.size()); k_offset += int_labels.size(); } return loss; } friend void serialize(const loss_multimulticlass_log_& item, std::ostream& out) { serialize("loss_multimulticlass_log_", out); serialize(item.get_labels(), out); } friend void deserialize(loss_multimulticlass_log_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_multimulticlass_log_") throw serialization_error("Unexpected version found while deserializing dlib::loss_multimulticlass_log_."); std::map> info; deserialize(info, in); item = loss_multimulticlass_log_(info); } friend std::ostream& operator<<(std::ostream& out, const loss_multimulticlass_log_& item) { out << "loss_multimulticlass_log, labels={"; for (auto i = item.possible_labels.begin(); i != item.possible_labels.end(); ) { auto& category = i->first; auto& labels = *(i->second); out << category << ":("; for (size_t j = 0; j < labels.size(); ++j) { out << labels[j]; if (j+1 < labels.size()) out << ","; } out << ")"; if (++i != item.possible_labels.end()) out << ", "; } out << "}"; return out; } friend void to_xml(const loss_multimulticlass_log_& item, std::ostream& out) { out << "\n"; out << item; out << 
"\n\n"; } private: std::map>> possible_labels; unsigned long total_num_labels = 0; // We make it true that: possible_labels[classifier][label_idx_lookup[classifier][label]] == label std::map> label_idx_lookup; // Scratch doesn't logically contribute to the state of this object. It's just // temporary scratch space used by this class. mutable resizable_tensor scratch; }; template using loss_multimulticlass_log = add_loss_layer; inline bool operator== (const std::string& lhs, const loss_multimulticlass_log_::classifier_output& rhs) { return lhs == static_cast(rhs); } inline bool operator== (const loss_multimulticlass_log_::classifier_output& lhs, const std::string& rhs) { return rhs == static_cast(lhs); } // ---------------------------------------------------------------------------------------- class loss_multibinary_log_ { public: typedef std::vector training_label_type; typedef std::vector output_label_type; loss_multibinary_log_() = default; loss_multibinary_log_(double gamma) : gamma(gamma) { DLIB_CASSERT(gamma >= 0); } template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); // Note that output_tensor.k() should match the number of labels. 
const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { output_label_type predictions(output_tensor.k(), 0); for (long k = 0; k < output_tensor.k(); ++k) { predictions[k] = out_data[i * output_tensor.k() + k]; } *iter++ = std::move(predictions); } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1); tt::sigmoid(grad, output_tensor); // The loss we output is the average loss over the mini-batch. const double scale = 1.0 / output_tensor.num_samples(); double loss = 0; float* g = grad.host(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { const long long num_label_categories = truth->size(); DLIB_CASSERT(output_tensor.k() == num_label_categories, "Number of label types should match the number of output channels. 
" "output_tensor.k(): " << output_tensor.k() << ", num_label_categories: "<< num_label_categories); for (long k = 0; k < output_tensor.k(); ++k) { const float y = (*truth)[k]; DLIB_CASSERT(y != 0, "y: " << y); const size_t idx = i * output_tensor.k() + k; if (y > 0) { const float temp = log1pexp(-out_data[idx]); const float focus = std::pow(1 - g[idx], gamma); loss += y * scale * temp * focus; g[idx] = y * scale * focus * (g[idx] * (gamma * temp + 1) - 1); } else { const float temp = -(-out_data[idx] - log1pexp(-out_data[idx])); const float focus = std::pow(g[idx], gamma); loss += -y * scale * temp * focus; g[idx] = -y * scale * focus * g[idx] * (gamma * temp + 1); } } } return loss; } double get_gamma () const { return gamma; } friend void serialize(const loss_multibinary_log_& item, std::ostream& out) { serialize("loss_multibinary_log_2", out); serialize(item.gamma, out); } friend void deserialize(loss_multibinary_log_& item, std::istream& in) { std::string version; deserialize(version, in); if (version == "loss_multibinary_log_") { item.gamma = 0; return; } else if (version == "loss_multibinary_log_2") { deserialize(item.gamma, in); } else { throw serialization_error("Unexpected version found while deserializing dlib::loss_multibinary_log_."); } } friend std::ostream& operator<<(std::ostream& out, const loss_multibinary_log_& item) { out << "loss_multibinary_log (gamma=" << item.gamma << ")"; return out; } friend void to_xml(const loss_multibinary_log_& item, std::ostream& out) { out << "\n"; } private: double gamma = 0; }; template using loss_multibinary_log = add_loss_layer; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- enum class use_image_pyramid : uint8_t { no, yes }; struct mmod_options { public: struct detector_window_details { detector_window_details() = default; detector_window_details(unsigned long w, unsigned long 
h) : width(w), height(h) {} detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {} unsigned long width = 0; unsigned long height = 0; std::string label; friend inline void serialize(const detector_window_details& item, std::ostream& out) { int version = 2; serialize(version, out); serialize(item.width, out); serialize(item.height, out); serialize(item.label, out); } friend inline void deserialize(detector_window_details& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1 && version != 2) throw serialization_error("Unexpected version found while deserializing dlib::mmod_options::detector_window_details"); deserialize(item.width, in); deserialize(item.height, in); if (version == 2) deserialize(item.label, in); } }; mmod_options() = default; std::vector detector_windows; double loss_per_false_alarm = 1; double loss_per_missed_target = 1; double truth_match_iou_threshold = 0.5; test_box_overlap overlaps_nms = test_box_overlap(0.4); test_box_overlap overlaps_ignore; bool use_bounding_box_regression = false; double bbr_lambda = 100; // This field is intentionally not serialized because I want people to really think hard // about ignoring the warnings that this suppresses. bool be_quiet = false; use_image_pyramid assume_image_pyramid = use_image_pyramid::yes; mmod_options ( const std::vector>& boxes, const unsigned long target_size, // We want the length of the longest dimension of the detector window to be this. const unsigned long min_target_size, // But we require that the smallest dimension of the detector window be at least this big. const double min_detector_window_overlap_iou = 0.75 ) { DLIB_CASSERT(0 < min_target_size && min_target_size <= target_size); DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1); // Figure out what detector windows we will need. 
for (auto& label : get_labels(boxes)) { for (auto ratio : find_covering_aspect_ratios(boxes, test_box_overlap(min_detector_window_overlap_iou), label)) { double detector_width; double detector_height; if (ratio < 1) { detector_height = target_size; detector_width = ratio*target_size; if (detector_width < min_target_size) { detector_height = min_target_size/ratio; detector_width = min_target_size; } } else { detector_width = target_size; detector_height = target_size/ratio; if (detector_height < min_target_size) { detector_width = min_target_size*ratio; detector_height = min_target_size; } } detector_window_details p((unsigned long)std::round(detector_width), (unsigned long)std::round(detector_height), label); detector_windows.push_back(p); } } DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes)."); set_overlap_nms(boxes); } mmod_options( use_image_pyramid assume_image_pyramid, const std::vector>& boxes, const double min_detector_window_overlap_iou = 0.75 ) : assume_image_pyramid(assume_image_pyramid) { DLIB_CASSERT(assume_image_pyramid == use_image_pyramid::no); DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1); // Figure out what detector windows we will need. for (auto& label : get_labels(boxes)) { for (auto rectangle : find_covering_rectangles(boxes, test_box_overlap(min_detector_window_overlap_iou), label)) { detector_windows.push_back(detector_window_details(rectangle.width(), rectangle.height(), label)); } } DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes)."); set_overlap_nms(boxes); } private: void set_overlap_nms(const std::vector>& boxes) { // Convert from mmod_rect to rectangle so we can call // find_tight_overlap_tester(). 
std::vector> temp; for (auto&& bi : boxes) { std::vector rtemp; for (auto&& b : bi) { if (b.ignore) continue; rtemp.push_back(b.rect); } temp.push_back(std::move(rtemp)); } overlaps_nms = find_tight_overlap_tester(temp); // Relax the non-max-suppression a little so that it doesn't accidentally make // it impossible for the detector to output boxes matching the training data. // This could be a problem with the tightest possible nms test since there is // some small variability in how boxes get positioned between the training data // and the coordinate system used by the detector when it runs. So relaxing it // here takes care of that. auto iou_thresh = advance_toward_1(overlaps_nms.get_iou_thresh()); auto percent_covered_thresh = advance_toward_1(overlaps_nms.get_percent_covered_thresh()); overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh); } static double advance_toward_1 ( double val ) { if (val < 1) val += (1-val)*0.1; return val; } static size_t count_overlaps ( const std::vector& rects, const test_box_overlap& overlaps, const rectangle& ref_box ) { size_t cnt = 0; for (auto& b : rects) { if (overlaps(b, ref_box)) ++cnt; } return cnt; } static std::vector find_rectangles_overlapping_all_others ( std::vector rects, const test_box_overlap& overlaps ) { std::vector exemplars; dlib::rand rnd; while(rects.size() > 0) { // Pick boxes at random and see if they overlap a lot of other boxes. We will try // 500 different boxes each iteration and select whichever hits the most others to // add to our exemplar set. rectangle best_ref_box; size_t best_cnt = 0; for (int iter = 0; iter < 500; ++iter) { rectangle ref_box = rects[rnd.get_random_64bit_number()%rects.size()]; size_t cnt = count_overlaps(rects, overlaps, ref_box); if (cnt >= best_cnt) { best_cnt = cnt; best_ref_box = ref_box; } } // Now mark all the boxes the new ref box hit as hit. 
for (size_t i = 0; i < rects.size(); ++i) { if (overlaps(rects[i], best_ref_box)) { // remove box from rects so we don't hit it again later swap(rects[i], rects.back()); rects.pop_back(); --i; } } exemplars.push_back(best_ref_box); } return exemplars; } static std::set get_labels ( const std::vector>& rects ) { std::set labels; for (auto& rr : rects) { for (auto& r : rr) labels.insert(r.label); } return labels; } static std::vector find_covering_aspect_ratios ( const std::vector>& rects, const test_box_overlap& overlaps, const std::string& label ) { std::vector boxes; // Make sure all the boxes have the same size and position, so that the only thing our // checks for overlap will care about is aspect ratio (i.e. scale and x,y position are // ignored). for (auto& bb : rects) { for (auto&& b : bb) { if (!b.ignore && b.label == label) boxes.push_back(move_rect(set_rect_area(b.rect,400*400), point(0,0))); } } std::vector ratios; for (auto r : find_rectangles_overlapping_all_others(boxes, overlaps)) ratios.push_back(r.width()/(double)r.height()); return ratios; } static std::vector find_covering_rectangles ( const std::vector>& rects, const test_box_overlap& overlaps, const std::string& label ) { std::vector boxes; // Make sure all the boxes have the same position, so that the we only check for // width and height. 
for (auto& bb : rects) { for (auto&& b : bb) { if (!b.ignore && b.label == label) boxes.push_back(rectangle(b.rect.width(), b.rect.height())); } } return find_rectangles_overlapping_all_others(boxes, overlaps); } }; inline void serialize(const mmod_options& item, std::ostream& out) { int version = 4; serialize(version, out); serialize(item.detector_windows, out); serialize(item.loss_per_false_alarm, out); serialize(item.loss_per_missed_target, out); serialize(item.truth_match_iou_threshold, out); serialize(item.overlaps_nms, out); serialize(item.overlaps_ignore, out); serialize(static_cast(item.assume_image_pyramid), out); serialize(item.use_bounding_box_regression, out); serialize(item.bbr_lambda, out); } inline void deserialize(mmod_options& item, std::istream& in) { int version = 0; deserialize(version, in); if (!(1 <= version && version <= 4)) throw serialization_error("Unexpected version found while deserializing dlib::mmod_options"); if (version == 1) { unsigned long width; unsigned long height; deserialize(width, in); deserialize(height, in); item.detector_windows = {mmod_options::detector_window_details(width, height)}; } else { deserialize(item.detector_windows, in); } deserialize(item.loss_per_false_alarm, in); deserialize(item.loss_per_missed_target, in); deserialize(item.truth_match_iou_threshold, in); deserialize(item.overlaps_nms, in); deserialize(item.overlaps_ignore, in); item.assume_image_pyramid = use_image_pyramid::yes; if (version >= 3) { uint8_t assume_image_pyramid = 0; deserialize(assume_image_pyramid, in); item.assume_image_pyramid = static_cast(assume_image_pyramid); } item.use_bounding_box_regression = mmod_options().use_bounding_box_regression; // use default value since this wasn't provided item.bbr_lambda = mmod_options().bbr_lambda; // use default value since this wasn't provided if (version >= 4) { deserialize(item.use_bounding_box_regression, in); deserialize(item.bbr_lambda, in); } } inline std::ostream& operator<<(std::ostream& 
out, const std::vector& detector_windows) { // write detector windows grouped by label // example output: aeroplane:74x30,131x30,70x45,54x70,198x30;bicycle:70x57,32x70,70x32,51x70,128x30,30x121;car:70x36,70x60,99x30,52x70,30x83,30x114,30x200 std::map> detector_windows_by_label; for (const auto& detector_window : detector_windows) detector_windows_by_label[detector_window.label].push_back(detector_window); size_t label_count = 0; for (const auto& i : detector_windows_by_label) { const auto& label = i.first; const auto& detector_windows = i.second; if (label_count++ > 0) out << ";"; out << label << ":"; for (size_t j = 0; j < detector_windows.size(); ++j) { out << detector_windows[j].width << "x" << detector_windows[j].height; if (j + 1 < detector_windows.size()) out << ","; } } return out; } // ---------------------------------------------------------------------------------------- class loss_mmod_ { struct intermediate_detection { intermediate_detection() = default; intermediate_detection( rectangle rect_ ) : rect(rect_), rect_bbr(rect_) {} intermediate_detection( rectangle rect_, double detection_confidence_, size_t tensor_offset_, long channel ) : rect(rect_), detection_confidence(detection_confidence_), tensor_offset(tensor_offset_), tensor_channel(channel), rect_bbr(rect_) {} // rect is the rectangle you get without any bounding box regression. So it's // the basic sliding window box (aka, the "anchor box"). rectangle rect; double detection_confidence = 0; size_t tensor_offset = 0; long tensor_channel = 0; // rect_bbr = rect + bounding box regression. So more accurate. Or if bbr is off then // this is just rect. The important thing about rect_bbr is that its the // rectangle we use for doing NMS. 
drectangle rect_bbr; size_t tensor_offset_dx = 0; size_t tensor_offset_dy = 0; size_t tensor_offset_dw = 0; size_t tensor_offset_dh = 0; bool operator<(const intermediate_detection& item) const { return detection_confidence < item.detection_confidence; } }; public: typedef std::vector training_label_type; typedef std::vector output_label_type; loss_mmod_() {} loss_mmod_(mmod_options options_) : options(options_) {} const mmod_options& get_options ( ) const { return options; } template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter, double adjust_threshold = 0 ) const { const tensor& output_tensor = sub.get_output(); if (options.use_bounding_box_regression) { DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()*5); } else { DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); } DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor()); std::vector dets_accum; output_label_type final_dets; for (long i = 0; i < output_tensor.num_samples(); ++i) { tensor_to_dets(input_tensor, output_tensor, i, dets_accum, adjust_threshold, sub); // Do non-max suppression final_dets.clear(); for (unsigned long i = 0; i < dets_accum.size(); ++i) { if (overlaps_any_box_nms(final_dets, dets_accum[i].rect_bbr)) continue; final_dets.push_back(mmod_rect(dets_accum[i].rect_bbr, dets_accum[i].detection_confidence, options.detector_windows[dets_accum[i].tensor_channel].label)); } *iter++ = std::move(final_dets); } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(sub.sample_expansion_factor() == 1); 
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); if (options.use_bounding_box_regression) { DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()*5); } else { DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); } double det_thresh_speed_adjust = 0; // we will scale the loss so that it doesn't get really huge const double scale = 1.0/(output_tensor.nr()*output_tensor.nc()*output_tensor.num_samples()*options.detector_windows.size()); double loss = 0; float* g = grad.host_write_only(); for (size_t i = 0; i < grad.size(); ++i) g[i] = 0; const float* out_data = output_tensor.host(); std::vector dets; for (long i = 0; i < output_tensor.num_samples(); ++i) { tensor_to_dets(input_tensor, output_tensor, i, dets, -options.loss_per_false_alarm + det_thresh_speed_adjust, sub); const unsigned long max_num_dets = 50 + truth->size()*5; // Prevent calls to tensor_to_dets() from running for a really long time // due to the production of an obscene number of detections. const unsigned long max_num_initial_dets = max_num_dets*100; if (dets.size() > max_num_initial_dets) { det_thresh_speed_adjust = std::max(det_thresh_speed_adjust,dets[max_num_initial_dets].detection_confidence + options.loss_per_false_alarm); } std::vector truth_idxs; truth_idxs.reserve(truth->size()); std::unordered_map idx_to_truth_rect; // The loss will measure the number of incorrect detections. A detection is // incorrect if it doesn't hit a truth rectangle or if it is a duplicate detection // on a truth rectangle. loss += truth->size()*options.loss_per_missed_target; for (auto&& x : *truth) { if (!x.ignore) { size_t k; point p; if(image_rect_to_feat_coord(p, input_tensor, x, x.label, sub, k, options.assume_image_pyramid)) { // Ignore boxes that can't be detected by the CNN. 
loss -= options.loss_per_missed_target; truth_idxs.push_back(-1); continue; } const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x(); const auto i = idx_to_truth_rect.find(idx); if (i != idx_to_truth_rect.end()) { if (!options.be_quiet) { // Ignore duplicate truth box in feature coordinates. std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << x.rect; std::cout << ", and we are ignoring it because it maps to the exact same feature coordinates "; std::cout << "as another truth rectangle located at " << i->second << "." << std::endl; } loss -= options.loss_per_missed_target; truth_idxs.push_back(-1); continue; } loss -= out_data[idx]; // compute gradient g[idx] = -scale; truth_idxs.push_back(idx); idx_to_truth_rect[idx] = x.rect; } else { // This box was ignored so shouldn't have been counted in the loss. loss -= options.loss_per_missed_target; truth_idxs.push_back(-1); } } // Measure the loss augmented score for the detections which hit a truth rect. std::vector truth_score_hits(truth->size(), 0); // keep track of which truth boxes we have hit so far. std::vector hit_truth_table(truth->size(), false); std::vector final_dets; // The point of this loop is to fill out the truth_score_hits array. 
for (size_t i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i) { if (overlaps_any_box_nms(final_dets, dets[i].rect_bbr)) continue; const auto& det_label = options.detector_windows[dets[i].tensor_channel].label; const std::pair hittruth = find_best_match(*truth, hit_truth_table, dets[i].rect, det_label); final_dets.push_back(dets[i].rect); const double truth_match = hittruth.first; // if hit truth rect if (truth_match > options.truth_match_iou_threshold) { // if this is the first time we have seen a detect which hit (*truth)[hittruth.second] const double score = dets[i].detection_confidence; if (hit_truth_table[hittruth.second] == false) { hit_truth_table[hittruth.second] = true; truth_score_hits[hittruth.second] += score; } else { truth_score_hits[hittruth.second] += score + options.loss_per_false_alarm; } } } // Check if any of the truth boxes are unobtainable because the NMS is // killing them. If so, automatically set those unobtainable boxes to // ignore and print a warning message to the user. for (size_t i = 0; i < hit_truth_table.size(); ++i) { if (!hit_truth_table[i] && !(*truth)[i].ignore) { // So we didn't hit this truth box. Is that because there is // another, different truth box, that overlaps it according to NMS? const std::pair hittruth = find_best_match(*truth, (*truth)[i], i); if (hittruth.second == i || (*truth)[hittruth.second].ignore) continue; rectangle best_matching_truth_box = (*truth)[hittruth.second]; if (options.overlaps_nms(best_matching_truth_box, (*truth)[i])) { const int idx = truth_idxs[i]; if (idx != -1) { // We are ignoring this box so we shouldn't have counted it in the // loss in the first place. So we subtract out the loss values we // added for it in the code above. loss -= options.loss_per_missed_target-out_data[idx]; g[idx] = 0; if (!options.be_quiet) { std::cout << "Warning, ignoring object. 
We encountered a truth rectangle located at " << (*truth)[i].rect; std::cout << " that is suppressed by non-max-suppression "; std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box << " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:" << box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl; } } } } } hit_truth_table.assign(hit_truth_table.size(), false); final_dets.clear(); // Now figure out which detections jointly maximize the loss and detection score sum. We // need to take into account the fact that allowing a true detection in the output, while // initially reducing the loss, may allow us to increase the loss later with many duplicate // detections. for (unsigned long i = 0; i < dets.size() && final_dets.size() < max_num_dets; ++i) { if (overlaps_any_box_nms(final_dets, dets[i].rect_bbr)) continue; const auto& det_label = options.detector_windows[dets[i].tensor_channel].label; const std::pair hittruth = find_best_match(*truth, hit_truth_table, dets[i].rect, det_label); const double truth_match = hittruth.first; if (truth_match > options.truth_match_iou_threshold) { if (truth_score_hits[hittruth.second] > options.loss_per_missed_target) { if (!hit_truth_table[hittruth.second]) { hit_truth_table[hittruth.second] = true; final_dets.push_back(dets[i]); loss -= options.loss_per_missed_target; // Now account for BBR loss and gradient if appropriate. 
if (options.use_bounding_box_regression) { double dx = out_data[dets[i].tensor_offset_dx]; double dy = out_data[dets[i].tensor_offset_dy]; double dw = out_data[dets[i].tensor_offset_dw]; double dh = out_data[dets[i].tensor_offset_dh]; dpoint p = dcenter(dets[i].rect); double w = dets[i].rect.width()-1; double h = dets[i].rect.height()-1; drectangle truth_box = (*truth)[hittruth.second].rect; dpoint p_truth = dcenter(truth_box); DLIB_CASSERT(w > 0); DLIB_CASSERT(h > 0); double target_dx = (p_truth.x() - p.x())/w; double target_dy = (p_truth.y() - p.y())/h; double target_dw = std::log((truth_box.width()-1)/w); double target_dh = std::log((truth_box.height()-1)/h); // compute smoothed L1 loss on BBR outputs. This loss // is just the MSE loss when the loss is small and L1 // when large. dx = dx-target_dx; dy = dy-target_dy; dw = dw-target_dw; dh = dh-target_dh; // use smoothed L1 double ldx = std::abs(dx)<1 ? 0.5*dx*dx : std::abs(dx)-0.5; double ldy = std::abs(dy)<1 ? 0.5*dy*dy : std::abs(dy)-0.5; double ldw = std::abs(dw)<1 ? 0.5*dw*dw : std::abs(dw)-0.5; double ldh = std::abs(dh)<1 ? 
0.5*dh*dh : std::abs(dh)-0.5; loss += options.bbr_lambda*(ldx + ldy + ldw + ldh); // now compute the derivatives of the smoothed L1 loss ldx = put_in_range(-1,1, dx); ldy = put_in_range(-1,1, dy); ldw = put_in_range(-1,1, dw); ldh = put_in_range(-1,1, dh); // also smoothed L1 gradient goes to gradient output g[dets[i].tensor_offset_dx] += scale*options.bbr_lambda*ldx; g[dets[i].tensor_offset_dy] += scale*options.bbr_lambda*ldy; g[dets[i].tensor_offset_dw] += scale*options.bbr_lambda*ldw; g[dets[i].tensor_offset_dh] += scale*options.bbr_lambda*ldh; } } else { final_dets.push_back(dets[i]); loss += options.loss_per_false_alarm; } } } else if (!overlaps_ignore_box(*truth, dets[i].rect)) { // didn't hit anything final_dets.push_back(dets[i]); loss += options.loss_per_false_alarm; } } for (auto&& x : final_dets) { loss += out_data[x.tensor_offset]; g[x.tensor_offset] += scale; } ++truth; g += output_tensor.k()*output_tensor.nr()*output_tensor.nc(); out_data += output_tensor.k()*output_tensor.nr()*output_tensor.nc(); } // END for (long i = 0; i < output_tensor.num_samples(); ++i) // Here we scale the loss so that it's roughly equal to the number of mistakes // in an image. Note that this scaling is different than the scaling we // applied to the gradient but it doesn't matter since the loss value isn't // used to update parameters. It's used only for display and to check if we // have converged. So it doesn't matter that they are scaled differently and // this way the loss that is displayed is readily interpretable to the user. 
return loss/output_tensor.num_samples(); } friend void serialize(const loss_mmod_& item, std::ostream& out) { serialize("loss_mmod_", out); serialize(item.options, out); } friend void deserialize(loss_mmod_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_mmod_") throw serialization_error("Unexpected version found while deserializing dlib::loss_mmod_."); deserialize(item.options, in); } friend std::ostream& operator<<(std::ostream& out, const loss_mmod_& item) { out << "loss_mmod\t ("; auto& opts = item.options; out << "detector_windows:(" << opts.detector_windows << ")"; out << ", loss per FA:" << opts.loss_per_false_alarm; out << ", loss per miss:" << opts.loss_per_missed_target; out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold; out << ", use_bounding_box_regression:" << opts.use_bounding_box_regression; if (opts.use_bounding_box_regression) out << ", bbr_lambda:" << opts.bbr_lambda; out << ", overlaps_nms:("<\n"; } private: template void tensor_to_dets ( const tensor& input_tensor, const tensor& output_tensor, long i, std::vector& dets_accum, double adjust_threshold, const net_type& net ) const { DLIB_CASSERT(net.sample_expansion_factor() == 1,net.sample_expansion_factor()); if (options.use_bounding_box_regression) { DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()*5); } else { DLIB_CASSERT(output_tensor.k() == (long)options.detector_windows.size()); } const float* out_data = output_tensor.host() + output_tensor.k()*output_tensor.nr()*output_tensor.nc()*i; // scan the final layer and output the positive scoring locations dets_accum.clear(); for (long k = 0; k < (long)options.detector_windows.size(); ++k) { for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { double score = out_data[(k*output_tensor.nr() + r)*output_tensor.nc() + c]; if (score > adjust_threshold) { dpoint p = output_tensor_to_input_tensor(net, point(c,r)); drectangle 
rect = centered_drect(p, options.detector_windows[k].width, options.detector_windows[k].height); rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect); dets_accum.push_back(intermediate_detection(rect, score, (k*output_tensor.nr() + r)*output_tensor.nc() + c, k)); if (options.use_bounding_box_regression) { const auto offset = options.detector_windows.size() + k*4; dets_accum.back().tensor_offset_dx = ((offset+0)*output_tensor.nr() + r)*output_tensor.nc() + c; dets_accum.back().tensor_offset_dy = ((offset+1)*output_tensor.nr() + r)*output_tensor.nc() + c; dets_accum.back().tensor_offset_dw = ((offset+2)*output_tensor.nr() + r)*output_tensor.nc() + c; dets_accum.back().tensor_offset_dh = ((offset+3)*output_tensor.nr() + r)*output_tensor.nc() + c; // apply BBR to dets_accum.back() double dx = out_data[dets_accum.back().tensor_offset_dx]; double dy = out_data[dets_accum.back().tensor_offset_dy]; double dw = out_data[dets_accum.back().tensor_offset_dw]; double dh = out_data[dets_accum.back().tensor_offset_dh]; dw = std::exp(dw); dh = std::exp(dh); double w = rect.width()-1; double h = rect.height()-1; rect = translate_rect(rect, dpoint(dx*w,dy*h)); rect = centered_drect(rect, w*dw+1, h*dh+1); dets_accum.back().rect_bbr = rect; } } } } } std::sort(dets_accum.rbegin(), dets_accum.rend()); } size_t find_best_detection_window ( rectangle rect, const std::string& label, use_image_pyramid assume_image_pyramid ) const { if (assume_image_pyramid == use_image_pyramid::yes) { rect = move_rect(set_rect_area(rect, 400*400), point(0,0)); } else { rect = rectangle(rect.width(), rect.height()); } // Figure out which detection window in options.detector_windows is most similar to rect // (in terms of aspect ratio, if assume_image_pyramid == use_image_pyramid::yes). 
size_t best_i = 0; double best_ratio_diff = -std::numeric_limits::infinity(); for (size_t i = 0; i < options.detector_windows.size(); ++i) { if (options.detector_windows[i].label != label) continue; rectangle det_window; if (options.assume_image_pyramid == use_image_pyramid::yes) { det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height); det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0)); } else { det_window = rectangle(options.detector_windows[i].width, options.detector_windows[i].height); } double iou = box_intersection_over_union(rect, det_window); if (iou > best_ratio_diff) { best_ratio_diff = iou; best_i = i; } } return best_i; } template bool image_rect_to_feat_coord ( point& tensor_p, const tensor& input_tensor, const rectangle& rect, const std::string& label, const net_type& net, size_t& det_idx, use_image_pyramid assume_image_pyramid ) const { if (!input_layer(net).image_contained_point(input_tensor,center(rect))) { std::ostringstream sout; sout << "Encountered a truth rectangle located at " << rect << " that is outside the image." << std::endl; sout << "The center of each truth rectangle must be within the image." << std::endl; throw impossible_labeling_error(sout.str()); } det_idx = find_best_detection_window(rect,label,assume_image_pyramid); double scale = 1.0; if (options.assume_image_pyramid == use_image_pyramid::yes) { // Compute the scale we need to be at to get from rect to our detection window. // Note that we compute the scale as the max of two numbers. It doesn't // actually matter which one we pick, because if they are very different then // it means the box can't be matched by the sliding window. But picking the // max causes the right error message to be selected in the logic below. scale = std::max(options.detector_windows[det_idx].width/(double)rect.width(), options.detector_windows[det_idx].height/(double)rect.height()); } else { // We don't want invariance to scale. 
scale = 1.0; } const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect); // compute the detection window that we would use at this position. tensor_p = center(mapped_rect); rectangle det_window = centered_rect(tensor_p, options.detector_windows[det_idx].width,options.detector_windows[det_idx].height); det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window); // make sure the rect can actually be represented by the image pyramid we are // using. if (box_intersection_over_union(rect, det_window) <= options.truth_match_iou_threshold) { std::cout << "Warning, ignoring object. We encountered a truth rectangle with a width and height of " << rect.width() << " and " << rect.height() << ". "; std::cout << "The image pyramid and sliding windows can't output a rectangle of this shape. "; const double detector_area = options.detector_windows[det_idx].width*options.detector_windows[det_idx].height; if (mapped_rect.area()/detector_area <= options.truth_match_iou_threshold) { std::cout << "This is because the rectangle is smaller than the best matching detection window, which has a width "; std::cout << "and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl; } else { std::cout << "This is either because (1) the final layer's features have too large of a stride across the image, limiting the possible locations the sliding window can search "; std::cout << "or (2) because the rectangle's aspect ratio is too different from the best matching detection window, "; std::cout << "which has a width and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl; } return true; } // now map through the CNN to the output layer. 
// NOTE(review): in this copy the contents of <...> template argument lists appear to be
// missing (e.g. "const std::vector& boxes", "std::pair find_best_match"); the tokens
// are reproduced unchanged here -- restore them from upstream dlib before building.
//
// Continuation of image_rect_to_feat_coord(): project the tensor-space point through
// the network and reject boxes whose point falls outside the output tensor.
            tensor_p = input_tensor_to_output_tensor(net,tensor_p);
            const tensor& output_tensor = net.get_output();
            if (!get_rect(output_tensor).contains(tensor_p))
            {
                std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << rect << " that is too close to the edge ";
                std::cout << "of the image to be captured by the CNN features." << std::endl;
                return true;
            }
            return false;
        }

        // Returns true if `rect` overlaps (as judged by options.overlaps_ignore) any box
        // in `boxes` that is flagged as ignore.
        bool overlaps_ignore_box (
            const std::vector& boxes,
            const rectangle& rect
        ) const
        {
            for (auto&& b : boxes)
            {
                if (b.ignore && options.overlaps_ignore(b, rect))
                    return true;
            }
            return false;
        }

        // Finds the non-ignored box with label `label` that has the highest IoU with
        // `rect`. The first pass skips boxes already marked in hit_truth_table. Only if
        // that pass finds no overlap at all (match == 0) is a second pass made that
        // allows them. Returns (best IoU, index), or (0, 0) if nothing overlaps.
        // NOTE(review): best_idx is unsigned int while i is unsigned long (narrowing).
        std::pair find_best_match(
            const std::vector& boxes,
            const std::vector& hit_truth_table,
            const rectangle& rect,
            const std::string& label
        ) const
        {
            double match = 0;
            unsigned int best_idx = 0;
            for (int allow_duplicate_hit = 0; allow_duplicate_hit <= 1 && match == 0; ++allow_duplicate_hit)
            {
                for (unsigned long i = 0; i < boxes.size(); ++i)
                {
                    if (boxes[i].ignore || boxes[i].label != label)
                        continue;
                    if (!allow_duplicate_hit && hit_truth_table[i])
                        continue;

                    const double new_match = box_intersection_over_union(rect, boxes[i]);
                    if (new_match > match)
                    {
                        match = new_match;
                        best_idx = i;
                    }
                }
            }

            return std::make_pair(match,best_idx);
        }

        // Label- and history-agnostic variant: returns (best IoU, index) over all
        // non-ignored boxes except boxes[excluded_idx], or (0, 0) if nothing overlaps.
        std::pair find_best_match(
            const std::vector& boxes,
            const rectangle& rect,
            const size_t excluded_idx
        ) const
        {
            double match = 0;
            unsigned int best_idx = 0;
            for (unsigned long i = 0; i < boxes.size(); ++i)
            {
                if (boxes[i].ignore || excluded_idx == i)
                    continue;

                const double new_match = box_intersection_over_union(rect, boxes[i]);
                if (new_match > match)
                {
                    match = new_match;
                    best_idx = i;
                }
            }

            return std::make_pair(match,best_idx);
        }

        // Returns true if `rect` overlaps (as judged by options.overlaps_nms) the .rect
        // member of any element of `rects`.
        template 
        inline bool overlaps_any_box_nms (
            const std::vector& rects,
            const rectangle& rect
        ) const
        {
            for (auto&& r : rects)
            {
                if (options.overlaps_nms(r.rect, rect))
                    return true;
            }
            return false;
        }

        mmod_options options;

    };

    template using loss_mmod = add_loss_layer;

// ----------------------------------------------------------------------------------------

    // Metric-learning loss: pulls same-label outputs to within (dist_thresh - margin)
    // of each other and pushes different-label outputs beyond (dist_thresh + margin).
    class loss_metric_
    {
    public:

        typedef unsigned long training_label_type;
        typedef matrix output_label_type;

        loss_metric_() = default;

        // Both margin_ and dist_thresh_ must be > 0 (asserted).
        loss_metric_(
            float margin_,
            float dist_thresh_
        ) : margin(margin_), dist_thresh(dist_thresh_)
        {
            DLIB_CASSERT(margin_ > 0);
            DLIB_CASSERT(dist_thresh_ > 0);
        }

        // Copies each sample's k-dimensional output vector into *iter.
        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
            DLIB_CASSERT(input_tensor.num_samples() != 0);
            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
            DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1);

            const float* p = output_tensor.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
            {
                *iter = mat(p,output_tensor.k(),1);

                ++iter;
                p += output_tensor.k();
            }
        }

        float get_margin() const { return margin; }
        float get_distance_threshold() const { return dist_thresh; }

        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss_value_and_gradient (
            const tensor& input_tensor,
            const_label_iterator truth, 
            SUBNET& sub
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            tensor& grad = sub.get_gradient_input();

            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
            DLIB_CASSERT(input_tensor.num_samples() != 0);
            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
            DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1);
            DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1);

            // temp = output_tensor * trans(output_tensor): the pairwise dot products that
            // are later combined into squared distances (xx + yy - 2*xy).
            temp.set_size(output_tensor.num_samples(), output_tensor.num_samples());
            grad_mul.copy_size(temp);

            tt::gemm(0, temp, 1, output_tensor, false, 
// NOTE(review): in this copy the contents of <...> template argument lists appear to be
// missing (e.g. "std::vector temp_threshs", "template using loss_metric = ..."); the
// tokens are reproduced unchanged here -- restore them from upstream dlib before building.
//
// Continuation of loss_metric_::compute_loss_value_and_gradient(): finishes the
// tt::gemm(...) call that fills temp, then computes the loss and its gradient.
                     output_tensor, true);

            std::vector temp_threshs;
            const float* d = temp.host();
            double loss = 0;
            // Counters start at a tiny positive value, presumably so the division in
            // `scale` below can never be by zero (the asserts after it require >= 1).
            double num_pos_samps = 0.0001;
            double num_neg_samps = 0.0001;
            // First pass over every unordered pair (r < c): count positive/negative pairs
            // and collect the squared distance of each negative pair.
            for (long r = 0; r < temp.num_samples(); ++r)
            {
                auto xx = d[r*temp.num_samples() + r];
                const auto x_label = *(truth + r);
                for (long c = r+1; c < temp.num_samples(); ++c)
                {
                    const auto y_label = *(truth + c);
                    if (x_label == y_label)
                    {
                        ++num_pos_samps;
                    }
                    else
                    {
                        ++num_neg_samps;

                        // Figure out what distance threshold, when applied to the negative pairs,
                        // causes there to be an equal number of positive and negative pairs.
                        auto yy = d[c*temp.num_samples() + c];
                        auto xy = d[r*temp.num_samples() + c];
                        // compute the distance between x and y samples.
                        auto d2 = xx + yy - 2*xy;
                        if (d2 < 0)
                            d2 = 0;
                        temp_threshs.push_back(d2);
                    }
                }
            }
            // The whole objective function is multiplied by this to scale the loss
            // relative to the number of things in the mini-batch.
            const double scale = 0.5/num_pos_samps;
            DLIB_CASSERT(num_pos_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs");
            DLIB_CASSERT(num_neg_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs");

            // neg_thresh is the distance of the k-th closest negative pair, where k is the
            // smaller of the two pair counts (the 0.0001 offsets are truncated away when
            // the double is converted to an index).
            std::sort(temp_threshs.begin(), temp_threshs.end());
            const float neg_thresh = std::sqrt(temp_threshs[std::min(num_pos_samps,num_neg_samps)-1]);

            // loop over all the pairs of training samples and compute the loss and
            // gradients.  Note that we only use the hardest negative pairs and that in
            // particular we pick the number of negative pairs equal to the number of
            // positive pairs so everything is balanced.
            float* gm = grad_mul.host();
            for (long r = 0; r < temp.num_samples(); ++r)
            {
                // The diagonal entry accumulates the coefficient applied to sample r itself.
                gm[r*temp.num_samples() + r] = 0;
                const auto x_label = *(truth + r);
                auto xx = d[r*temp.num_samples() + r];
                for (long c = 0; c < temp.num_samples(); ++c)
                {
                    if (r==c)
                        continue;
                    const auto y_label = *(truth + c);
                    auto yy = d[c*temp.num_samples() + c];
                    auto xy = d[r*temp.num_samples() + c];

                    // compute the distance between x and y samples.
                    auto d2 = xx + yy - 2*xy;
                    if (d2 <= 0)
                        d2 = 0;
                    else 
                        d2 = std::sqrt(d2);

                    // It should be noted that the derivative of length(x-y) with respect
                    // to the x vector is the unit vector (x-y)/length(x-y).  If you stare
                    // at the code below long enough you will see that it's just an
                    // application of this formula.

                    if (x_label == y_label)
                    {
                        // Things with the same label should have distances < dist_thresh between
                        // them.  If not then we experience non-zero loss.
                        if (d2 < dist_thresh-margin)
                        {
                            gm[r*temp.num_samples() + c] = 0;
                        }
                        else
                        {
                            // NOTE(review): if margin >= dist_thresh this branch can be reached
                            // with d2 == 0, dividing by zero; unlike the negative-pair branch
                            // below, d2 is not clamped here.
                            loss += scale*(d2 - (dist_thresh-margin));
                            gm[r*temp.num_samples() + r] += scale/d2;
                            gm[r*temp.num_samples() + c] = -scale/d2;
                        }
                    }
                    else
                    {
                        // Things with different labels should have distances > dist_thresh between
                        // them.  If not then we experience non-zero loss.
                        // Only the hardest negatives (d2 <= neg_thresh) contribute.
                        if (d2 > dist_thresh+margin || d2 > neg_thresh)
                        {
                            gm[r*temp.num_samples() + c] = 0;
                        }
                        else
                        {
                            loss += scale*((dist_thresh+margin) - d2);
                            // don't divide by zero (or a really small number)
                            d2 = std::max(d2, 0.001f);
                            gm[r*temp.num_samples() + r] -= scale/d2;
                            gm[r*temp.num_samples() + c] = scale/d2;
                        }
                    }
                }
            }

            // grad = grad_mul * output_tensor: each sample's gradient is a weighted
            // combination of the output vectors.
            tt::gemm(0, grad, 1, grad_mul, false, output_tensor, false); 

            return loss;
        }

        friend void serialize(const loss_metric_& item, std::ostream& out)
        {
            serialize("loss_metric_2", out);
            serialize(item.margin, out);
            serialize(item.dist_thresh, out);
        }

        friend void deserialize(loss_metric_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "loss_metric_") 
            {
                // These values used to be hard coded, so for this version of the metric
                // learning loss we just use these values.
                item.margin = 0.1f;
                item.dist_thresh = 0.75f;
                return;
            }
            else if (version == "loss_metric_2")
            {
                deserialize(item.margin, in);
                deserialize(item.dist_thresh, in);
            }
            else
            {
                throw serialization_error("Unexpected version found while deserializing dlib::loss_metric_. Instead found " + version);
            }
        }

        // NOTE(review): as it appears in this copy the statement below is malformed (the
        // text after `<` looks truncated), and the to_xml() overload that the sibling
        // loss classes define appears to be missing; restore this span from upstream.
        friend std::ostream& operator<<(std::ostream& out, const loss_metric_& item )
        {
            out << "loss_metric (margin="<\n";
        }

    private:
        float margin = 0.04f;
        float dist_thresh = 0.6f;


        // These variables are only here to avoid being reallocated over and over in
        // compute_loss_value_and_gradient()
        mutable resizable_tensor temp, grad_mul;

    };

    template using loss_metric = add_loss_layer;

// ----------------------------------------------------------------------------------------

    // Pairwise ranking loss: samples with a positive label are "relevant", those with a
    // negative label are "non-relevant", and label 0 samples are ignored.
    class loss_ranking_
    {
    public:

        typedef float training_label_type; // nominally +1/-1
        typedef float output_label_type; // ranking score

        // Each sample's single scalar output is its ranking score.
        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        ) const
        {
            DLIB_CASSERT(sub.sample_expansion_factor() == 1);

            const tensor& output_tensor = sub.get_output();

            DLIB_CASSERT(output_tensor.nr() == 1 && 
                         output_tensor.nc() == 1 && 
                         output_tensor.k() == 1);
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());

            const float* out_data = output_tensor.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
            {
                *iter++ = out_data[i];
            }
        }

        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss_value_and_gradient (
            const tensor& input_tensor,
            const_label_iterator truth, 
            SUBNET& sub
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            tensor& grad = sub.get_gradient_input();

            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
            DLIB_CASSERT(input_tensor.num_samples() != 0);
            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
            DLIB_CASSERT(output_tensor.nr() == 1 && 
                         output_tensor.nc() == 1 && 
                         output_tensor.k() == 1);
            DLIB_CASSERT(grad.nr() == 1 && 
                         grad.nc() == 1 && 
                         grad.k() == 1);

            // Scores and sample indices, split by the sign of the truth label.
            std::vector rel_scores;
            std::vector nonrel_scores;
            std::vector rel_idx, nonrel_idx;
// NOTE(review): in this copy the contents of <...> template argument lists appear to be
// missing (e.g. "std::vector rel_counts", "template using loss_ranking = ..."); the
// tokens are reproduced unchanged here -- restore them from upstream dlib before building.
//
// Continuation of loss_ranking_::compute_loss_value_and_gradient().
            const float* out_data = output_tensor.host();
            // Every element of g is assigned below: label-0 samples in this loop, all
            // other samples in the two loops after count_ranking_inversions().
            float* g = grad.host_write_only();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
            {
                const float y = *truth++;
                // Scores are stored shifted by -y (presumably acting as a label-dependent
                // margin when inversions are counted -- confirm against
                // count_ranking_inversions()).
                if (y > 0)
                {
                    rel_scores.push_back(out_data[i]-y);
                    rel_idx.push_back(i);
                }
                else if (y < 0)
                {
                    nonrel_scores.push_back(out_data[i]-y);
                    nonrel_idx.push_back(i);
                }
                else
                {
                    g[i] = 0;
                }
            }


            std::vector rel_counts;
            std::vector nonrel_counts;
            count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
            const unsigned long total_pairs = rel_scores.size()*nonrel_scores.size();
            DLIB_CASSERT(total_pairs > 0, "You can't give a ranking mini-batch that contains only one class.  Both classes must be represented.");
            // Loss and gradient are averaged over all (relevant, non-relevant) pairs.
            const double scale = 1.0/total_pairs;


            double loss = 0;
            for (unsigned long k = 0; k < rel_counts.size(); ++k)
            {
                loss -= rel_counts[k]*rel_scores[k];
                g[rel_idx[k]] = -1.0*rel_counts[k]*scale;
            }

            for (unsigned long k = 0; k < nonrel_counts.size(); ++k)
            {
                loss += nonrel_counts[k]*nonrel_scores[k];
                g[nonrel_idx[k]] = nonrel_counts[k]*scale;
            }

            return loss*scale;
        }

        friend void serialize(const loss_ranking_& , std::ostream& out)
        {
            serialize("loss_ranking_", out);
        }

        friend void deserialize(loss_ranking_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "loss_ranking_")
                throw serialization_error("Unexpected version found while deserializing dlib::loss_ranking_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const loss_ranking_& )
        {
            out << "loss_ranking";
            return out;
        }

        // NOTE(review): only "\n" is written here; the XML text looks truncated in this
        // copy -- compare with upstream.
        friend void to_xml(const loss_ranking_& /*item*/, std::ostream& out)
        {
            out << "\n";
        }

    };

    template using loss_ranking = add_loss_layer;

// ----------------------------------------------------------------------------------------

    // Squared-error regression loss on one scalar output per sample, averaged over the
    // mini-batch.
    class loss_mean_squared_
    {
    public:

        typedef float training_label_type;
        typedef float output_label_type;

        // Each sample's single scalar output is its predicted value.
        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        ) const
        {
            DLIB_CASSERT(sub.sample_expansion_factor() == 1);

            const tensor& output_tensor = sub.get_output();

            DLIB_CASSERT(output_tensor.nr() == 1 && 
                         output_tensor.nc() == 1 && 
                         output_tensor.k() == 1);
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());

            const float* out_data = output_tensor.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
            {
                *iter++ = out_data[i];
            }
        }


        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss_value_and_gradient (
            const tensor& input_tensor,
            const_label_iterator truth, 
            SUBNET& sub
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            tensor& grad = sub.get_gradient_input();

            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
            DLIB_CASSERT(input_tensor.num_samples() != 0);
            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
            DLIB_CASSERT(output_tensor.nr() == 1 && 
                         output_tensor.nc() == 1 && 
                         output_tensor.k() == 1);
            DLIB_CASSERT(grad.nr() == 1 && 
                         grad.nc() == 1 && 
                         grad.k() == 1);

            // The loss we output is the average loss over the mini-batch.
const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host_write_only(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { const float y = *truth++; const float temp1 = y - out_data[i]; const float temp2 = scale*temp1; loss += temp2*temp1; g[i] = -temp2; } return loss; } friend void serialize(const loss_mean_squared_& , std::ostream& out) { serialize("loss_mean_squared_", out); } friend void deserialize(loss_mean_squared_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_mean_squared_") throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_."); } friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_& ) { out << "loss_mean_squared"; return out; } friend void to_xml(const loss_mean_squared_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_mean_squared = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_epsilon_insensitive_ { public: typedef float training_label_type; typedef float output_label_type; loss_epsilon_insensitive_() = default; loss_epsilon_insensitive_(double eps) : eps(eps) { DLIB_CASSERT(eps >= 0, "You can't set a negative error epsilon."); } double get_epsilon () const { return eps; } void set_epsilon(double e) { DLIB_CASSERT(e >= 0, "You can't set a negative error epsilon."); eps = e; } template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 && output_tensor.k() == 1); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const float* out_data = output_tensor.host(); for (long i = 0; i < 
output_tensor.num_samples(); ++i) { *iter++ = out_data[i]; } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1 && output_tensor.k() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1 && grad.k() == 1); // The loss we output is the average loss over the mini-batch. const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host_write_only(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { const float y = *truth++; const float err = out_data[i]-y; if (err > eps) { loss += scale*(err-eps); g[i] = scale; } else if (err < -eps) { loss += scale*(eps-err); g[i] = -scale; } } return loss; } friend void serialize(const loss_epsilon_insensitive_& item, std::ostream& out) { serialize("loss_epsilon_insensitive_", out); serialize(item.eps, out); } friend void deserialize(loss_epsilon_insensitive_& item, std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_epsilon_insensitive_") throw serialization_error("Unexpected version found while deserializing dlib::loss_epsilon_insensitive_."); deserialize(item.eps, in); } friend std::ostream& operator<<(std::ostream& out, const loss_epsilon_insensitive_& item) { out << "loss_epsilon_insensitive epsilon: " << item.eps; return out; } friend void to_xml(const loss_epsilon_insensitive_& item, std::ostream& out) { out << "\n"; } private: 
double eps = 1; }; template using loss_epsilon_insensitive = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_mean_squared_multioutput_ { public: typedef matrix training_label_type; typedef matrix output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1) DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { *iter++ = mat(out_data, output_tensor.k(), 1); out_data += output_tensor.k(); } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.nr() == 1 && output_tensor.nc() == 1); DLIB_CASSERT(grad.nr() == 1 && grad.nc() == 1); DLIB_CASSERT(grad.k() == output_tensor.k()); const long k = output_tensor.k(); for (long idx = 0; idx < output_tensor.num_samples(); ++idx) { const_label_iterator truth_matrix_ptr = (truth + idx); DLIB_CASSERT((*truth_matrix_ptr).nr() == k && (*truth_matrix_ptr).nc() == 1); } // The loss we output is the average loss over the mini-batch. 
const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host_write_only(); const float* out_data = output_tensor.host(); matrix ytrue; for (long i = 0; i < output_tensor.num_samples(); ++i) { ytrue = *truth++; for (long j = 0; j < output_tensor.k(); ++j) { const float y = ytrue(j, 0); const float temp1 = y - *out_data++; const float temp2 = scale*temp1; loss += temp2*temp1; *g = -temp2; ++g; } } return loss; } friend void serialize(const loss_mean_squared_multioutput_& , std::ostream& out) { serialize("loss_mean_squared_multioutput_", out); } friend void deserialize(loss_mean_squared_multioutput_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_mean_squared_multioutput_") throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_."); } friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_multioutput_& ) { out << "loss_mean_squared_multioutput"; return out; } friend void to_xml(const loss_mean_squared_multioutput_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_mean_squared_multioutput = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_binary_log_per_pixel_ { public: typedef matrix training_label_type; typedef matrix output_label_type; template < typename SUB_TYPE, typename label_iterator > static void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.k() == 1); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const float* const out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) { iter->set_size(output_tensor.nr(), output_tensor.nc()); for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < 
// NOTE(review): in this copy the contents of <...> template argument lists appear to be
// missing (e.g. "std::numeric_limits::max()", "typedef matrix training_label_type");
// the tokens are reproduced unchanged here -- restore them from upstream dlib before building.
//
// Continuation of loss_binary_log_per_pixel_::to_label(): each label element is the
// raw (unthresholded) output value at that position.
                         output_tensor.nc(); ++c)
                    {
                        iter->operator()(r, c) = out_data[tensor_index(output_tensor, i, 0, r, c)];
                    }
                }
            }
        }


        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss_value_and_gradient (
            const tensor& input_tensor,
            const_label_iterator truth,
            SUBNET& sub
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            tensor& grad = sub.get_gradient_input();

            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
            DLIB_CASSERT(input_tensor.num_samples() != 0);
            DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
            DLIB_CASSERT(output_tensor.k() == 1);
            DLIB_CASSERT(output_tensor.nr() == grad.nr() &&
                         output_tensor.nc() == grad.nc() &&
                         output_tensor.k() == grad.k());
            // Each truth matrix must have the same nr x nc as the output tensor.
            for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
            {
                const_label_iterator truth_matrix_ptr = (truth + idx);
                DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() &&
                             truth_matrix_ptr->nc() == output_tensor.nc(),
                             "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", "
                             "output size = " << output_tensor.nr() << " x " << output_tensor.nc());
            }

            // The loss/gradient computation itself is delegated to the CUDA or CPU helper
            // member selected at compile time by DLIB_USE_CUDA.
            double loss;
#ifdef DLIB_USE_CUDA
            cuda_compute(truth, output_tensor, grad, loss);
#else
            cpu_compute(truth, output_tensor, grad, loss);
#endif
            return loss;
        }

        friend void serialize(const loss_binary_log_per_pixel_& , std::ostream& out)
        {
            serialize("loss_binary_log_per_pixel_", out);
        }

        friend void deserialize(loss_binary_log_per_pixel_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "loss_binary_log_per_pixel_")
                throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_log_per_pixel_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const loss_binary_log_per_pixel_& )
        {
            out << "loss_binary_log_per_pixel";
            return out;
        }

        friend void to_xml(const loss_binary_log_per_pixel_& /*item*/, std::ostream& out)
        {
            out << "\n";
        }

    private:
#ifdef DLIB_USE_CUDA
        cuda::compute_loss_binary_log_per_pixel cuda_compute;
#else
        cpu::compute_loss_binary_log_per_pixel cpu_compute;
#endif
    };

    template using loss_binary_log_per_pixel = add_loss_layer;

// ----------------------------------------------------------------------------------------

    class loss_multiclass_log_per_pixel_
    {
    public:

        // In semantic segmentation, if you don't know the ground-truth of some pixel,
        // set the label of that pixel to this value. When you do so, the pixel will be
        // ignored when computing gradients.
        static const uint16_t label_to_ignore = std::numeric_limits::max();


        // In semantic segmentation, 65535 classes ought to be enough for anybody.
        typedef matrix training_label_type;
        typedef matrix output_label_type;

        // For every sample and (r, c) position, the label is the index of the channel
        // with the largest output.
        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        static void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        )
        {
            DLIB_CASSERT(sub.sample_expansion_factor() == 1);

            const tensor& output_tensor = sub.get_output();

            DLIB_CASSERT(output_tensor.k() >= 1); // Note that output_tensor.k() should match the number of labels.
            DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());

            const float* const out_data = output_tensor.host();

            // The index of the largest output for each element is the label.
const auto find_label = [&](long sample, long r, long c) { uint16_t label = 0; float max_value = out_data[tensor_index(output_tensor, sample, 0, r, c)]; for (long k = 1; k < output_tensor.k(); ++k) { const float value = out_data[tensor_index(output_tensor, sample, k, r, c)]; if (value > max_value) { label = static_cast(k); max_value = value; } } return label; }; for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) { iter->set_size(output_tensor.nr(), output_tensor.nc()); for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { // The index of the largest output for this element is the label. iter->operator()(r, c) = find_label(i, r, c); } } } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.k() >= 1); DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); DLIB_CASSERT(output_tensor.nr() == grad.nr() && output_tensor.nc() == grad.nc() && output_tensor.k() == grad.k()); for (long idx = 0; idx < output_tensor.num_samples(); ++idx) { const_label_iterator truth_matrix_ptr = (truth + idx); DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && truth_matrix_ptr->nc() == output_tensor.nc(), "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); } double loss; #ifdef DLIB_USE_CUDA cuda_compute(truth, output_tensor, grad, loss); #else cpu_compute(truth, 
output_tensor, grad, loss); #endif return loss; } friend void serialize(const loss_multiclass_log_per_pixel_& , std::ostream& out) { serialize("loss_multiclass_log_per_pixel_", out); } friend void deserialize(loss_multiclass_log_per_pixel_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_multiclass_log_per_pixel_") throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_per_pixel_."); } friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_per_pixel_& ) { out << "loss_multiclass_log_per_pixel"; return out; } friend void to_xml(const loss_multiclass_log_per_pixel_& /*item*/, std::ostream& out) { out << "\n"; } private: #ifdef DLIB_USE_CUDA cuda::compute_loss_multiclass_log_per_pixel cuda_compute; #else cpu::compute_loss_multiclass_log_per_pixel cpu_compute; #endif }; template using loss_multiclass_log_per_pixel = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_multiclass_log_per_pixel_weighted_ { public: typedef dlib::weighted_label weighted_label; typedef matrix training_label_type; typedef matrix output_label_type; template < typename SUB_TYPE, typename label_iterator > static void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) { loss_multiclass_log_per_pixel_::to_label(input_tensor, sub, iter); } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == 
output_tensor.num_samples()); DLIB_CASSERT(output_tensor.k() >= 1); DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); DLIB_CASSERT(output_tensor.nr() == grad.nr() && output_tensor.nc() == grad.nc() && output_tensor.k() == grad.k()); for (long idx = 0; idx < output_tensor.num_samples(); ++idx) { const_label_iterator truth_matrix_ptr = (truth + idx); DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && truth_matrix_ptr->nc() == output_tensor.nc(), "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); } double loss; #ifdef DLIB_USE_CUDA cuda_compute(truth, output_tensor, grad, loss); #else cpu_compute(truth, output_tensor, grad, loss); #endif return loss; } friend void serialize(const loss_multiclass_log_per_pixel_weighted_& , std::ostream& out) { serialize("loss_multiclass_log_per_pixel_weighted_", out); } friend void deserialize(loss_multiclass_log_per_pixel_weighted_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_multiclass_log_per_pixel_weighted_") throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_per_pixel_weighted_."); } friend std::ostream& operator<<(std::ostream& out, const loss_multiclass_log_per_pixel_weighted_& ) { out << "loss_multiclass_log_per_pixel_weighted"; return out; } friend void to_xml(const loss_multiclass_log_per_pixel_weighted_& /*item*/, std::ostream& out) { out << "\n"; } private: #ifdef DLIB_USE_CUDA cuda::compute_loss_multiclass_log_per_pixel_weighted cuda_compute; #else cpu::compute_loss_multiclass_log_per_pixel_weighted cpu_compute; #endif }; template using loss_multiclass_log_per_pixel_weighted = add_loss_layer; // ---------------------------------------------------------------------------------------- class loss_mean_squared_per_pixel_ { public: typedef matrix training_label_type; typedef matrix output_label_type; template < 
typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.k() == 1, "output k = " << output_tensor.k()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) { iter->set_size(output_tensor.nr(), output_tensor.nc()); for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { iter->operator()(r, c) = out_data[tensor_index(output_tensor, i, 0, r, c)]; } } } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.k() >= 1); DLIB_CASSERT(output_tensor.k() < std::numeric_limits::max()); DLIB_CASSERT(output_tensor.nr() == grad.nr() && output_tensor.nc() == grad.nc() && output_tensor.k() == grad.k()); for (long idx = 0; idx < output_tensor.num_samples(); ++idx) { const_label_iterator truth_matrix_ptr = (truth + idx); DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() && truth_matrix_ptr->nc() == output_tensor.nc(), "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", " "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); } // The loss we output is the average loss over the mini-batch, and also over each 
element of the matrix output. const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc()); double loss = 0; float* const g = grad.host(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth) { for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { const float y = truth->operator()(r, c); const size_t idx = tensor_index(output_tensor, i, 0, r, c); const float temp1 = y - out_data[idx]; const float temp2 = scale*temp1; loss += temp2*temp1; g[idx] = -temp2; } } } return loss; } friend void serialize(const loss_mean_squared_per_pixel_& , std::ostream& out) { serialize("loss_mean_squared_per_pixel_", out); } friend void deserialize(loss_mean_squared_per_pixel_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_mean_squared_per_pixel_") throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_per_pixel_."); } friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_per_pixel_& ) { out << "loss_mean_squared_per_pixel"; return out; } friend void to_xml(const loss_mean_squared_per_pixel_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_mean_squared_per_pixel = add_loss_layer; // ---------------------------------------------------------------------------------------- template class loss_mean_squared_per_channel_and_pixel_ { public: typedef std::array, _num_channels> training_label_type; typedef std::array, _num_channels> output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { DLIB_CASSERT(sub.sample_expansion_factor() == 1); const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(output_tensor.k() == _num_channels, "output k = " << output_tensor.k()); DLIB_CASSERT(input_tensor.num_samples() == 
output_tensor.num_samples()); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) { for (long k = 0; k < output_tensor.k(); ++k) { (*iter)[k].set_size(output_tensor.nr(), output_tensor.nc()); for (long r = 0; r < output_tensor.nr(); ++r) { for (long c = 0; c < output_tensor.nc(); ++c) { (*iter)[k].operator()(r, c) = out_data[tensor_index(output_tensor, i, k, r, c)]; } } } } } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(output_tensor.k() == _num_channels); DLIB_CASSERT(output_tensor.nr() == grad.nr() && output_tensor.nc() == grad.nc() && output_tensor.k() == grad.k()); for (long idx = 0; idx < output_tensor.num_samples(); ++idx) { const_label_iterator truth_matrix_ptr = (truth + idx); DLIB_CASSERT((*truth_matrix_ptr).size() == _num_channels); for (long k = 0; k < output_tensor.k(); ++k) { DLIB_CASSERT((*truth_matrix_ptr)[k].nr() == output_tensor.nr() && (*truth_matrix_ptr)[k].nc() == output_tensor.nc(), "truth size = " << (*truth_matrix_ptr)[k].nr() << " x " << (*truth_matrix_ptr)[k].nc() << ", " "output size = " << output_tensor.nr() << " x " << output_tensor.nc()); } } double loss; #ifdef DLIB_USE_CUDA cuda_compute(truth, output_tensor, grad, loss); #else cpu_compute(truth, output_tensor, grad, loss); #endif return loss; } friend void serialize(const loss_mean_squared_per_channel_and_pixel_& , std::ostream& out) { serialize("loss_mean_squared_per_channel_and_pixel_", 
out); } friend void deserialize(loss_mean_squared_per_channel_and_pixel_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_mean_squared_per_channel_and_pixel_") throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_per_channel_and_pixel_."); } friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_per_channel_and_pixel_& ) { out << "loss_mean_squared_per_channel_and_pixel"; return out; } friend void to_xml(const loss_mean_squared_per_channel_and_pixel_& /*item*/, std::ostream& out) { out << "\n"; } private: #ifdef DLIB_USE_CUDA cuda::compute_loss_mean_squared_per_channel_and_pixel cuda_compute; #else cpu::compute_loss_mean_squared_per_channel_and_pixel cpu_compute; #endif }; template using loss_mean_squared_per_channel_and_pixel = add_loss_layer, SUBNET>; // ---------------------------------------------------------------------------------------- class loss_dot_ { public: typedef matrix training_label_type; typedef matrix output_label_type; template < typename SUB_TYPE, typename label_iterator > void to_label ( const tensor& input_tensor, const SUB_TYPE& sub, label_iterator iter ) const { const tensor& output_tensor = sub.get_output(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); for (long i = 0; i < output_tensor.num_samples(); ++i) *iter++ = trans(rowm(mat(output_tensor),i)); } template < typename const_label_iterator, typename SUBNET > double compute_loss_value_and_gradient ( const tensor& input_tensor, const_label_iterator truth, SUBNET& sub ) const { const tensor& output_tensor = sub.get_output(); tensor& grad = sub.get_gradient_input(); DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(input_tensor.num_samples() != 0); 
DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); const long network_output_dims = output_tensor.size()/output_tensor.num_samples(); // The loss we output is the average loss over the mini-batch. const double scale = 1.0/output_tensor.num_samples(); double loss = 0; float* g = grad.host(); const float* out_data = output_tensor.host(); for (long i = 0; i < output_tensor.num_samples(); ++i) { DLIB_CASSERT(truth->size() == network_output_dims, "The network must output a vector with the same dimensionality as the training labels. " << "\ntruth->size(): " << truth->size() << "\nnetwork_output_dims: " << network_output_dims); const float* t = &(*truth++)(0); for (long j = 0; j < network_output_dims; ++j) { g[j] = -t[j]*scale; loss -= out_data[j]*t[j]; } g += network_output_dims; out_data += network_output_dims; } return loss*scale; } friend void serialize(const loss_dot_& , std::ostream& out) { serialize("loss_dot_", out); } friend void deserialize(loss_dot_& , std::istream& in) { std::string version; deserialize(version, in); if (version != "loss_dot_") throw serialization_error("Unexpected version found while deserializing dlib::loss_dot_."); } friend std::ostream& operator<<(std::ostream& out, const loss_dot_& ) { out << "loss_dot"; return out; } friend void to_xml(const loss_dot_& /*item*/, std::ostream& out) { out << "\n"; } }; template using loss_dot = add_loss_layer; // ---------------------------------------------------------------------------------------- struct yolo_options { public: struct anchor_box_details { anchor_box_details() = default; anchor_box_details(unsigned long w, unsigned long h) : width(w), height(h) {} unsigned long width = 0; unsigned long height = 0; friend inline void serialize(const anchor_box_details& item, std::ostream& out) { int version = 0; serialize(version, out); 
serialize(item.width, out); serialize(item.height, out); } friend inline void deserialize(anchor_box_details& item, std::istream& in) { int version = 0; deserialize(version, in); deserialize(item.width, in); deserialize(item.height, in); } }; yolo_options() = default; template
\n";
        if (!out)
            throw std::ios::failure("error occurred in cpp_pretty_printer_kernel_1::print");

        unsigned long scope = 0; // counts the number of new scopes we have entered 
                        // since we were at a scope where functions can be declared

        bool recently_seen_class_keyword = false;
            // true if we have seen the keywords class, struct, or enum and
            // we have not seen any identifiers or { characters

        bool recently_seen_include = false;
            // true if we have seen the #include keyword and have not seen double
            // quoted text or >

        bool recently_seen_new_scope = false;  
            // true if we have seen the keywords class, namespace, or struct and
            // we have not seen the characters {, ), or ; since then

        bool recently_seen_paren = false;
            // true if we have seen a ) and we have only seen white_space or comments since

        bool in_initialization_list = false;
            // true if we have seen a ) followed by any white space or comments and then
            // followed by a : (in scope==0 with recently_seen_preprocessor==false) and we 
            // have not yet seen the character { or ;

        bool recently_seen_preprocessor = false;
            // true if we have seen the #pragma or #if or #define or #elif keywords and have 
            // not seen an end of line.

        bool recently_seen_extern = false;
            // true if we have seen the extern keyword and haven't seen a ; or { yet.

        unsigned long paren_count = 0; 
            // this is the number of ( we have seen minus the number of ) we have
            // seen.
            

        int type;
        stack scopes; // a stack to hold old scopes
        std::string token, temp;
        t.get_token(type,token);
        while (type != tok::END_OF_FILE)
        {
            switch (type)
            {
            case tok::IDENTIFIER: // ------------------------------------------
                if ( recently_seen_class_keyword)
                {
                    // this might be a class name so check if there is a 
                    // ; or identifier or * or & coming up.
                    type = t.peek_type();
                    temp.clear();
                    if (type == tok::WHITE_SPACE)
                    {
                        t.get_token(type,temp);
                        if (temp.find_first_of("\n\r") != std::string::npos)
                            recently_seen_preprocessor = false;
                    }
                    if (t.peek_type() != tok::IDENTIFIER &&
                        t.peek_token() != "*" && t.peek_token() != "&")
                    {
                        // this is the name of a class or struct in a class or
                        // struct declaration.
                        out << "" << token << "" << temp;
                    }
                    else
                    {
                        out << token << temp;
                    }
                }
                else if ( !in_initialization_list &&
                     !recently_seen_preprocessor )
                {
                    // this might be a function name so check if there is a 
                    // ( coming up.
                    type = t.peek_type();
                    temp.clear();
                    if (type == tok::WHITE_SPACE)
                    {
                        t.get_token(type,temp);
                        type = t.peek_type();
                    }
                    if (type == tok::OTHER && t.peek_token() == "(")
                    {
                        if (scope == 0 && paren_count == 0)
                        {
                            // this is a function definition or prototype
                            out << "" << token << "" << temp;
                        }
                        else
                        {
                            // this is a function call (probably) 
                            out << "" << token << "" << temp;
                        }
                    }
                    else
                    {
                        out << token << temp;
                    }
                }
                else
                {
                    out << token;
                }
                


                recently_seen_class_keyword = false;
                recently_seen_paren = false;
                break;

            case tok::KEYWORD: // ---------------------------------------------
                if (scope == 0 && token == "operator")
                {
                    // Doing this is sort of weird since operator is really a keyword
                    // but I just like how this looks.
                    out << "" << token << "";
                }
                // this isn't a keyword if it is something like #include 
                else if ( token == "true" || token == "false")
                {
                    // color 'true' and 'false' the same way we color numbers
                    out << "" << token << "";
                }
                else if (!recently_seen_include) 
                {
                    // This is a normal keyword
                    if (token == "char" || token == "unsigned" || token == "signed" ||
                        token == "short" || token == "int" || token == "long" || 
                        token == "float" || token == "double" || token == "bool" ||
                        token == "void" || token == "size_t" || token == "wchar_t")
                    {
                        out << "" << token << "";
                    }
                    else
                    {
                        out << "" << token << "";
                    }
                }
                else
                {
                    out << token;
                }

                if (token == "#include") 
                {
                    recently_seen_include = true;
                }
                else if (token == "class")
                {
                    recently_seen_new_scope = true;
                    recently_seen_class_keyword = true;
                }
                else if (token == "namespace")
                {
                    recently_seen_new_scope = true;
                }
                else if (token == "enum")
                {
                    recently_seen_class_keyword = true;
                }
                else if (token == "struct")
                {
                    recently_seen_new_scope = true;
                    recently_seen_class_keyword = true;
                }
                else if (token == "#pragma" || token == "#if" || token == "#define" || token == "#elif")
                {
                    recently_seen_preprocessor = true;
                }
                else if (token == "extern")
                {
                    recently_seen_extern = true;
                }
                recently_seen_paren = false;
                break;

            case tok::COMMENT: // ---------------------------------------------
                {
                    // if this is a special anchor comment
                    if (token.size() > 4 &&
                        token[0] == '/' &&
                        token[1] == '*' &&
                        token[2] == '!' &&
                        token[3] == 'A' &&
                        token[4] == ' '
                    )
                    {
                        temp = token;
                        std::istringstream sin(token);
                        sin >> temp;
                        sin >> temp;
                        sin.get();
                        // if there was still more stuff in the token then we are ok.
                        if (sin)
                            out << "";
                    }
                    out << "" << htmlify(token) << "";
                }
                break;

            case tok::SINGLE_QUOTED_TEXT: // ----------------------------------
                {
                    out << "" << htmlify(token) << "";
                    recently_seen_paren = false;
                }
                break;

            case tok::NUMBER: // -----------------------------------------
                {
                    out << "" << token << "";
                    recently_seen_include = false;
                }
                break;

            case tok::WHITE_SPACE: // -----------------------------------------
                {
                    out << token;
                    if (token.find_first_of("\n\r") != std::string::npos)
                        recently_seen_preprocessor = false;
                }
                break;

            case tok::DOUBLE_QUOTED_TEXT: // ----------------------------------
                {
                    if (recently_seen_include)
                    {
                        // this is the name of an included file
                        recently_seen_include = false;
                        out << "" << htmlify(token) << "";                
                    }
                    else
                    {
                        // this is just a normal quoted string
                        out << "" << htmlify(token) << "";
                    }
                    recently_seen_paren = false;
                }
                break;

            case tok::OTHER: // -----------------------------------------------               
                switch (token[0])
                {
                case '{':
                    out << "{";  
                    // if we are entering a new scope
                    if (recently_seen_new_scope || recently_seen_extern)
                    {
                        recently_seen_new_scope = false;
                        scopes.push(scope);
                        scope = 0;
                    }
                    else
                    {
                        ++scope;
                    }
                    in_initialization_list = false;
                    recently_seen_paren = false;
                    recently_seen_class_keyword = false;
                    recently_seen_extern = false;
                    break;
                case '}':
                    out << "}";
                    if (scope > 0)
                    {
                        --scope;
                    }
                    else if (scopes.size())
                    {
                        scopes.pop(scope);
                    }
                    recently_seen_paren = false;
                    break;

                case ':':
                    out << ':';
                    if (recently_seen_paren && scope == 0 && 
                        recently_seen_preprocessor == false)
                    {
                        in_initialization_list = true;
                    }
                    recently_seen_paren = false;
                    break;

                case ';': 
                    out << ';';
                    recently_seen_new_scope = false;
                    recently_seen_paren = false;
                    recently_seen_extern = false;
                    break;

                case ')':
                    out << ")";
                    recently_seen_paren = true;
                    recently_seen_new_scope = false;
                    --paren_count;
                    break;

                case '(':
                    out << "(";
                    recently_seen_paren = false;
                    ++paren_count;
                    break;

                case '>':
                    recently_seen_include = false;
                    out << ">";
                    recently_seen_paren = false;
                    break;

                case '<':
                    out << "<";
                    recently_seen_paren = false;
                    break;

                case '&':
                    out << "&";
                    recently_seen_paren = false;
                    break;

                case '=':
                case '+':
                case '-':
                case '/':
                case '*':
                case '!':
                case '|':
                case '%':
                    out << "" << token << "";
                    recently_seen_paren = false;
                    break;

                default:
                    out << token;
                    recently_seen_paren = false;
                    break;

                } // switch (token[0])
                break;

            } // switch (type)

            t.get_token(type,token);
        } // while (type != tok::END_OF_FILE)


        out << "\n